diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..848678b35 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Apply Runic formatter +24375d13a4645398b66909acf82708d9e45a4766 \ No newline at end of file diff --git a/.github/workflows/PreCommit.yml b/.github/workflows/PreCommit.yml index 9a5cfd70c..a51ae0f7e 100644 --- a/.github/workflows/PreCommit.yml +++ b/.github/workflows/PreCommit.yml @@ -9,7 +9,7 @@ on: push: branches: - main - tags: ["*"] + tags: ['*'] pull_request: types: [opened, reopened, synchronize, ready_for_review] workflow_dispatch: @@ -21,7 +21,7 @@ jobs: - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@v2 - uses: julia-actions/cache@v2 - - run: julia -e 'using Pkg; Pkg.add("JuliaFormatter")' - - uses: astral-sh/setup-uv@v6 + - run: julia -e 'using Pkg; Pkg.add("Runic")' + - uses: astral-sh/setup-uv@v7 - run: uv tool install pre-commit - run: pre-commit run --all-files --show-diff-on-failure --color always diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c3b01a95e..2e54e9959 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,8 +6,8 @@ repos: - id: check-added-large-files - id: check-merge-conflict - id: no-commit-to-branch -- repo: "https://github.com/domluna/JuliaFormatter.jl" - rev: "v2.1.6" # or whatever the desired release is +- repo: https://github.com/fredrikekre/runic-pre-commit + rev: v1.0.0 hooks: - - id: "julia-formatter" + - id: runic fail_fast: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d3aad94fa..69ebb3762 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,7 +9,7 @@ _Only the two maintainers are allowed to approve and merge pull requests._ Feel free to ping them if they do not answer within a week or so. Apart from the conditions above, this repository follows the [ColPrac](https://github.com/SciML/ColPrac) best practices and [Conventional Commits](https://www.conventionalcommits.org/en/). -Its code is formatted using [JuliaFormatter.jl](https://github.com/domluna/JuliaFormatter.jl) with [BlueStyle](https://github.com/JuliaDiff/BlueStyle). +Its code is formatted using [Runic.jl](https://github.com/fredrikekre/Runic.jl). As part of continuous integration, a set of formal tests is run using [pre-commit](https://pre-commit.com/). We invite you to install pre-commit so that these checks are performed locally before you open or update a pull request. You can refer to the [dev guide](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/dev/dev_guide/) for details on the package structure and the testing pipeline. diff --git a/DifferentiationInterface/README.md b/DifferentiationInterface/README.md index a1c7dd5ed..109e9bcf9 100644 --- a/DifferentiationInterface/README.md +++ b/DifferentiationInterface/README.md @@ -4,13 +4,13 @@ [![Build Status](https://github.com/JuliaDiff/DifferentiationInterface.jl/actions/workflows/Test.yml/badge.svg?branch=main)](https://github.com/JuliaDiff/DifferentiationInterface.jl/actions/workflows/Test.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/JuliaDiff/DifferentiationInterface.jl/branch/main/graph/badge.svg?flag=DI)](https://app.codecov.io/gh/JuliaDiff/DifferentiationInterface.jl) -[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/JuliaDiff/BlueStyle) +[![code style: runic](https://img.shields.io/badge/code_style-%E1%9A%B1%E1%9A%A2%E1%9A%BE%E1%9B%81%E1%9A%B2-black)](https://github.com/fredrikekre/Runic.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![DOI](https://zenodo.org/badge/740973714.svg)](https://zenodo.org/doi/10.5281/zenodo.11092033) -| Package | Docs | -|:----------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| -| DifferentiationInterface | [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/dev/) | +| Package | Docs | +| :--------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| DifferentiationInterface | [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/dev/) | | DifferentiationInterfaceTest | [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterfaceTest/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/) | An interface to various automatic differentiation (AD) backends in Julia. diff --git a/DifferentiationInterface/docs/make.jl b/DifferentiationInterface/docs/make.jl index 797bcd423..49f39ad6c 100644 --- a/DifferentiationInterface/docs/make.jl +++ b/DifferentiationInterface/docs/make.jl @@ -21,11 +21,11 @@ readme_str = replace(readme_str, "> [!CAUTION]\n> " => "!!! warning\n ") write(joinpath(@__DIR__, "src", "index.md"), readme_str) makedocs(; - modules=[DifferentiationInterface], - authors="Guillaume Dalle, Adrian Hill", - sitename="DifferentiationInterface.jl", - format=Documenter.HTML(; assets=["assets/favicon.ico"]), - pages=[ + modules = [DifferentiationInterface], + authors = "Guillaume Dalle, Adrian Hill", + sitename = "DifferentiationInterface.jl", + format = Documenter.HTML(; assets = ["assets/favicon.ico"]), + pages = [ "Home" => "index.md", "Tutorials" => ["tutorials/basic.md", "tutorials/advanced.md"], "Explanation" => [ @@ -37,13 +37,13 @@ makedocs(; "api.md", "dev_guide.md", ], - plugins=[links], + plugins = [links], ) deploydocs(; - repo="github.com/JuliaDiff/DifferentiationInterface.jl", - devbranch="main", - dirname="DifferentiationInterface", - tag_prefix="DifferentiationInterface-", - push_preview=false, + repo = "github.com/JuliaDiff/DifferentiationInterface.jl", + devbranch = "main", + dirname = "DifferentiationInterface", + tag_prefix = "DifferentiationInterface-", + push_preview = false, ) diff --git a/DifferentiationInterface/docs/src/assets/logo.jl b/DifferentiationInterface/docs/src/assets/logo.jl index d21ef420e..36d716f3f 100644 --- a/DifferentiationInterface/docs/src/assets/logo.jl +++ b/DifferentiationInterface/docs/src/assets/logo.jl @@ -32,13 +32,13 @@ logo = begin # Define center points and corners of triangle center = Point(0, 4) - corners = ngon(center, 16, 3, -0.75π / 2; vertices=true) + corners = ngon(center, 16, 3, -0.75π / 2; vertices = true) # Draw three blended partials for (i, c) in enumerate(corners) b = blend(c, 0, center, 50, colors[i], black) setblend(b) - text("∂", c; valign=:middle, halign=:center, angle=0.68π + 2π / 3 * i) + text("∂", c; valign = :middle, halign = :center, angle = 0.68π + 2π / 3 * i) end finish() diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 07f74edbb..8ec2ab649 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -1,46 +1,46 @@ ## Pullback -struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} +struct ChainRulesPullbackPrepSamePoint{SIG, Y, PB} <: DI.PullbackPrep{SIG} _sig::Val{SIG} y::Y pb::PB end function DI.prepare_pullback_nokwarg( - strict::Val, - f, - backend::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}; -) where {C} + strict::Val, + f, + backend::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C} + ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( - f, - prep::DI.NoPullbackPrep, - backend::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}; -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C} + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) + _sig = DI.signature(f, backend, x, ty, contexts...; strict = DI.is_strict(prep)) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) return ChainRulesPullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( - f, - prep::DI.NoPullbackPrep, - backend::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -51,13 +51,13 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f, - prep::ChainRulesPullbackPrepSamePoint, - backend::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::ChainRulesPullbackPrepSamePoint, + backend::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy @@ -67,13 +67,13 @@ function DI.value_and_pullback( end function DI.pullback( - f, - prep::ChainRulesPullbackPrepSamePoint, - backend::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::ChainRulesPullbackPrepSamePoint, + backend::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; pb) = prep tx = map(ty) do dy diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index bd286dc34..e784e4d96 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -11,15 +11,15 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward function DI.prepare_pushforward_nokwarg( - strict::Val, f, backend::AutoDiffractor, x, tx::NTuple -) + strict::Val, f, backend::AutoDiffractor, x, tx::NTuple + ) _sig = DI.signature(f, backend, x, tx; strict) return DI.NoPushforwardPrep(_sig) end function DI.pushforward( - f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple -) + f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple + ) DI.check_prep(f, prep, backend, x, tx) ty = map(tx) do dx # code copied from Diffractor.jl @@ -31,8 +31,8 @@ function DI.pushforward( end function DI.value_and_pushforward( - f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple -) + f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple + ) DI.check_prep(f, prep, backend, x, tx) return f(x), DI.pushforward(f, prep, backend, x, tx) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 3dd70bbc7..98ecc0db2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,19 +1,19 @@ ## Pushforward -struct EnzymeOneArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG} +struct EnzymeOneArgPushforwardPrep{SIG, DF, DC} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} df::DF context_shadows::DC end function DI.prepare_pushforward_nokwarg( - strict::Val, - f::F, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}; -) where {F,C,B} + strict::Val, + f::F, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C} + ) where {F, C, B} _sig = DI.signature(f, backend, x, tx, contexts...; strict) df = function_shadow(f, backend, Val(B)) mode = forward_withprimal(backend) @@ -22,13 +22,13 @@ function DI.prepare_pushforward_nokwarg( end function DI.value_and_pushforward( - f::F, - prep::EnzymeOneArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::EnzymeOneArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; df, context_shadows) = prep mode = forward_withprimal(backend) @@ -41,13 +41,13 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward( - f::F, - prep::EnzymeOneArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f::F, + prep::EnzymeOneArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; df, context_shadows) = prep mode = forward_withprimal(backend) @@ -59,13 +59,13 @@ function DI.value_and_pushforward( end function DI.pushforward( - f::F, - prep::EnzymeOneArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::EnzymeOneArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; df, context_shadows) = prep mode = forward_noprimal(backend) @@ -78,13 +78,13 @@ function DI.pushforward( end function DI.pushforward( - f::F, - prep::EnzymeOneArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f::F, + prep::EnzymeOneArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; df, context_shadows) = prep mode = forward_noprimal(backend) @@ -96,14 +96,14 @@ function DI.pushforward( end function DI.value_and_pushforward!( - f::F, - ty::NTuple, - prep::EnzymeOneArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + ty::NTuple, + prep::EnzymeOneArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) # dy cannot be passed anyway y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) @@ -112,14 +112,14 @@ function DI.value_and_pushforward!( end function DI.pushforward!( - f::F, - ty::NTuple, - prep::EnzymeOneArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + ty::NTuple, + prep::EnzymeOneArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) # dy cannot be passed anyway new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...) @@ -129,7 +129,7 @@ end ## Gradient -struct EnzymeForwardGradientPrep{SIG,B,DF,DC,O} <: DI.GradientPrep{SIG} +struct EnzymeForwardGradientPrep{SIG, B, DF, DC, O} <: DI.GradientPrep{SIG} _sig::Val{SIG} _valB::Val{B} df::DF @@ -138,12 +138,12 @@ struct EnzymeForwardGradientPrep{SIG,B,DF,DC,O} <: DI.GradientPrep{SIG} end function DI.prepare_gradient_nokwarg( - strict::Val, - f::F, - backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoEnzyme{<:ForwardMode, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) valB = to_val(DI.pick_batchsize(backend, x)) df = function_shadow(f, backend, valB) @@ -154,61 +154,61 @@ function DI.prepare_gradient_nokwarg( end function DI.gradient( - f::F, - prep::EnzymeForwardGradientPrep{SIG,B}, - backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,SIG,B,C} + f::F, + prep::EnzymeForwardGradientPrep{SIG, B}, + backend::AutoEnzyme{<:ForwardMode, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, SIG, B, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows, basis_shadows) = prep mode = forward_noprimal(backend) f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) derivs = gradient( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows + mode, f_and_df, x, annotated_contexts...; chunk = Val(B), shadows = basis_shadows ) return first(derivs) end function DI.value_and_gradient( - f::F, - prep::EnzymeForwardGradientPrep{SIG,B}, - backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,SIG,B,C} + f::F, + prep::EnzymeForwardGradientPrep{SIG, B}, + backend::AutoEnzyme{<:ForwardMode, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, SIG, B, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows, basis_shadows) = prep mode = forward_withprimal(backend) f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) (; derivs, val) = gradient( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows + mode, f_and_df, x, annotated_contexts...; chunk = Val(B), shadows = basis_shadows ) return val, first(derivs) end function DI.gradient!( - f::F, - grad, - prep::EnzymeForwardGradientPrep{SIG,B}, - backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,SIG,B,C} + f::F, + grad, + prep::EnzymeForwardGradientPrep{SIG, B}, + backend::AutoEnzyme{<:ForwardMode, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, SIG, B, C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end function DI.value_and_gradient!( - f::F, - grad, - prep::EnzymeForwardGradientPrep{SIG,B}, - backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,SIG,B,C} + f::F, + grad, + prep::EnzymeForwardGradientPrep{SIG, B}, + backend::AutoEnzyme{<:ForwardMode, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, SIG, B, C} DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -216,7 +216,7 @@ end ## Jacobian -struct EnzymeForwardOneArgJacobianPrep{SIG,B,DF,DC,O} <: DI.JacobianPrep{SIG} +struct EnzymeForwardOneArgJacobianPrep{SIG, B, DF, DC, O} <: DI.JacobianPrep{SIG} _sig::Val{SIG} _valB::Val{B} df::DF @@ -226,12 +226,12 @@ struct EnzymeForwardOneArgJacobianPrep{SIG,B,DF,DC,O} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian_nokwarg( - strict::Val, - f::F, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) valB = to_val(DI.pick_batchsize(backend, x)) @@ -245,63 +245,63 @@ function DI.prepare_jacobian_nokwarg( end function DI.jacobian( - f::F, - prep::EnzymeForwardOneArgJacobianPrep{SIG,B}, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,SIG,B,C} + f::F, + prep::EnzymeForwardOneArgJacobianPrep{SIG, B}, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, SIG, B, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows, basis_shadows, output_length) = prep mode = forward_noprimal(backend) f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) derivs = jacobian( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows + mode, f_and_df, x, annotated_contexts...; chunk = Val(B), shadows = basis_shadows ) jac_tensor = first(derivs) return maybe_reshape(jac_tensor, output_length, length(x)) end function DI.value_and_jacobian( - f::F, - prep::EnzymeForwardOneArgJacobianPrep{SIG,B}, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,SIG,B,C} + f::F, + prep::EnzymeForwardOneArgJacobianPrep{SIG, B}, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, SIG, B, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows, basis_shadows, output_length) = prep mode = forward_withprimal(backend) f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) (; derivs, val) = jacobian( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows + mode, f_and_df, x, annotated_contexts...; chunk = Val(B), shadows = basis_shadows ) jac_tensor = first(derivs) return val, maybe_reshape(jac_tensor, output_length, length(x)) end function DI.jacobian!( - f::F, - jac, - prep::EnzymeForwardOneArgJacobianPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,C} + f::F, + jac, + prep::EnzymeForwardOneArgJacobianPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end function DI.value_and_jacobian!( - f::F, - jac, - prep::EnzymeForwardOneArgJacobianPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,C} + f::F, + jac, + prep::EnzymeForwardOneArgJacobianPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index f0d2a2d91..bd8bdc04b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,20 +1,20 @@ ## Pushforward -struct EnzymeTwoArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG} +struct EnzymeTwoArgPushforwardPrep{SIG, DF, DC} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} df!::DF context_shadows::DC end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}; -) where {F,B,C} + strict::Val, + f!::F, + y, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C} + ) where {F, B, C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) df! = function_shadow(f!, backend, Val(B)) mode = forward_noprimal(backend) @@ -23,14 +23,14 @@ function DI.prepare_pushforward_nokwarg( end function DI.value_and_pushforward( - f!::F, - y, - prep::EnzymeTwoArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::EnzymeTwoArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; df!, context_shadows) = prep mode = forward_noprimal(backend) @@ -45,14 +45,14 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward( - f!::F, - y, - prep::EnzymeTwoArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f!::F, + y, + prep::EnzymeTwoArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; df!, context_shadows) = prep mode = forward_noprimal(backend) @@ -66,29 +66,29 @@ function DI.value_and_pushforward( end function DI.pushforward( - f!::F, - y, - prep::EnzymeTwoArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::EnzymeTwoArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) _, ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...) return ty end function DI.value_and_pushforward!( - f!::F, - y, - ty::NTuple{B}, - prep::EnzymeTwoArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f!::F, + y, + ty::NTuple{B}, + prep::EnzymeTwoArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; df!, context_shadows) = prep mode = forward_noprimal(backend) @@ -101,15 +101,15 @@ function DI.value_and_pushforward!( end function DI.pushforward!( - f!::F, - y, - ty::NTuple, - prep::EnzymeTwoArgPushforwardPrep, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + ty::NTuple, + prep::EnzymeTwoArgPushforwardPrep, + backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) return ty diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl index 08918f618..41dad3588 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl @@ -11,29 +11,30 @@ function __init__() # robust against internal changes condition = ( isdefined(Enzyme, :Compiler) && - Enzyme.Compiler isa Module && - isdefined(Enzyme.Compiler, :EnzymeError) && - Enzyme.Compiler.EnzymeError isa DataType + Enzyme.Compiler isa Module && + isdefined(Enzyme.Compiler, :EnzymeError) && + Enzyme.Compiler.EnzymeError isa DataType ) condition || return nothing # see https://github.com/JuliaLang/julia/issues/58367 for why this isn't easier - for n in names(Enzyme.Compiler; all=true) + for n in names(Enzyme.Compiler; all = true) T = getfield(Enzyme.Compiler, n) if T isa DataType && T <: Enzyme.Compiler.EnzymeError # robust against internal changes Base.Experimental.register_error_hint(T) do io, exc if occursin("EnzymeMutabilityException", string(nameof(T))) - printstyled(io, HINT_START("function_annotation"); bold=true) + printstyled(io, HINT_START("function_annotation"); bold = true) printstyled( io, "\n\n\tAutoEnzyme(; function_annotation=Enzyme.Duplicated)"; - color=:cyan, - bold=true, + color = :cyan, + bold = true, ) - printstyled(io, HINT_END; italic=true) + printstyled(io, HINT_END; italic = true) end # EnzymeRuntimeActivityError is no longer a concrete type since https://github.com/EnzymeAD/Enzyme.jl/pull/2555 (now a UnionAll) so we cannot define a hint end end end + return end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 67b3989f0..c81c5cd44 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -1,10 +1,10 @@ function seeded_autodiff_thunk( - rmode::ReverseModeSplit{ReturnPrimal}, - dresult, - f::FA, - ::Type{RA}, - args::Vararg{Annotation,N}, -) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N} + rmode::ReverseModeSplit{ReturnPrimal}, + dresult, + f::FA, + ::Type{RA}, + args::Vararg{Annotation, N}, + ) where {ReturnPrimal, FA <: Annotation, RA <: Annotation, N} forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...) tape, result, shadow_result = forward(f, args...) if RA <: Active @@ -21,12 +21,12 @@ function seeded_autodiff_thunk( end function batch_seeded_autodiff_thunk( - rmode::ReverseModeSplit{ReturnPrimal}, - dresults::NTuple{B}, - f::FA, - ::Type{RA}, - args::Vararg{Annotation,N}, -) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N} + rmode::ReverseModeSplit{ReturnPrimal}, + dresults::NTuple{B}, + f::FA, + ::Type{RA}, + args::Vararg{Annotation, N}, + ) where {ReturnPrimal, B, FA <: Annotation, RA <: Annotation, N} rmode_rightwidth = ReverseSplitWidth(rmode, Val(B)) forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...) tape, result, shadow_results = forward(f, args...) @@ -47,7 +47,7 @@ end ## Pullback -struct EnzymeReverseOneArgPullbackPrep{SIG,DF,DC,Y} <: DI.PullbackPrep{SIG} +struct EnzymeReverseOneArgPullbackPrep{SIG, DF, DC, Y} <: DI.PullbackPrep{SIG} _sig::Val{SIG} df::DF context_shadows::DC @@ -55,13 +55,13 @@ struct EnzymeReverseOneArgPullbackPrep{SIG,DF,DC,Y} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback_nokwarg( - strict::Val, - f::F, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{B}, - contexts::Vararg{DI.Context,C}; -) where {F,B,C} + strict::Val, + f::F, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{B}, + contexts::Vararg{DI.Context, C} + ) where {F, B, C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) df = function_shadow(f, backend, Val(B)) mode = reverse_split_withprimal(backend) @@ -73,13 +73,13 @@ end ### Out-of-place function DI.value_and_pullback( - f::F, - prep::EnzymeReverseOneArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::EnzymeReverseOneArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) @@ -100,13 +100,13 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f::F, - prep::EnzymeReverseOneArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f::F, + prep::EnzymeReverseOneArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) @@ -127,13 +127,13 @@ function DI.value_and_pullback( end function DI.pullback( - f::F, - prep::EnzymeReverseOneArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::EnzymeReverseOneArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) return last(DI.value_and_pullback(f, prep, backend, x, ty, contexts...)) end @@ -141,14 +141,14 @@ end ### In-place function DI.value_and_pullback!( - f::F, - tx::NTuple{1}, - prep::EnzymeReverseOneArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + tx::NTuple{1}, + prep::EnzymeReverseOneArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) @@ -164,14 +164,14 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - f::F, - tx::NTuple{B}, - prep::EnzymeReverseOneArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f::F, + tx::NTuple{B}, + prep::EnzymeReverseOneArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) @@ -186,33 +186,33 @@ function DI.value_and_pullback!( end function DI.pullback!( - f::F, - tx::NTuple, - prep::EnzymeReverseOneArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + tx::NTuple, + prep::EnzymeReverseOneArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) return last(DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)) end ## Gradient -struct EnzymeGradientPrep{SIG,DF,DC} <: DI.GradientPrep{SIG} +struct EnzymeGradientPrep{SIG, DF, DC} <: DI.GradientPrep{SIG} _sig::Val{SIG} df::DF context_shadows::DC end function DI.prepare_gradient_nokwarg( - strict::Val, - f::F, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) df = function_shadow(f, backend, Val(1)) mode = reverse_withprimal(backend) @@ -223,12 +223,12 @@ end ### Enzyme gradient API (only constants) function DI.gradient( - f::F, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,C} + f::F, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows) = prep mode = reverse_noprimal(backend) @@ -239,12 +239,12 @@ function DI.gradient( end function DI.value_and_gradient( - f::F, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, - x, - contexts::Vararg{DI.Constant,C}, -) where {F,C} + f::F, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}, <:Union{Nothing, Const}}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows) = prep mode = reverse_withprimal(backend) @@ -255,12 +255,12 @@ function DI.value_and_gradient( end function DI.gradient!( - f::F, - grad, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, - x, -) where {F} + f::F, + grad, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}, <:Union{Nothing, Const}}, + x, + ) where {F} DI.check_prep(f, prep, backend, x) (; df) = prep mode = reverse_noprimal(backend) @@ -270,12 +270,12 @@ function DI.gradient!( end function DI.value_and_gradient!( - f::F, - grad, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, - x, -) where {F} + f::F, + grad, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}, <:Union{Nothing, Const}}, + x, + ) where {F} DI.check_prep(f, prep, backend, x) (; df) = prep mode = reverse_withprimal(backend) @@ -287,12 +287,12 @@ end ### Generic function DI.gradient( - f::F, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows) = prep mode = reverse_noprimal(backend) @@ -312,12 +312,12 @@ function DI.gradient( end function DI.value_and_gradient( - f::F, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows) = prep mode = reverse_withprimal(backend) @@ -337,13 +337,13 @@ function DI.value_and_gradient( end function DI.gradient!( - f::F, - grad, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows) = prep mode = reverse_noprimal(backend) @@ -355,13 +355,13 @@ function DI.gradient!( end function DI.value_and_gradient!( - f::F, - grad, - prep::EnzymeGradientPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + prep::EnzymeGradientPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) (; df, context_shadows) = prep mode = reverse_withprimal(backend) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 18c4d9c68..ddf5f774c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,6 +1,6 @@ ## Pullback -struct EnzymeReverseTwoArgPullbackPrep{SIG,DF,DC,TY} <: DI.PullbackPrep{SIG} +struct EnzymeReverseTwoArgPullbackPrep{SIG, DF, DC, TY} <: DI.PullbackPrep{SIG} _sig::Val{SIG} df!::DF context_shadows::DC @@ -8,14 +8,14 @@ struct EnzymeReverseTwoArgPullbackPrep{SIG,DF,DC,TY} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{B}, - contexts::Vararg{DI.Context,C}; -) where {F,B,C} + strict::Val, + f!::F, + y, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{B}, + contexts::Vararg{DI.Context, C} + ) where {F, B, C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) df! = function_shadow(f!, backend, Val(B)) mode = reverse_noprimal(backend) @@ -25,14 +25,14 @@ function DI.prepare_pullback_nokwarg( end function DI.value_and_pullback( - f!::F, - y, - prep::EnzymeReverseTwoArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x::Number, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::EnzymeReverseTwoArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x::Number, + ty::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) (; df!, context_shadows, ty_copy) = prep copyto!(only(ty_copy), only(ty)) @@ -49,14 +49,14 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f!::F, - y, - prep::EnzymeReverseTwoArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x::Number, - ty::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f!::F, + y, + prep::EnzymeReverseTwoArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x::Number, + ty::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) (; df!, context_shadows, ty_copy) = prep foreach(copyto!, ty_copy, ty) @@ -73,14 +73,14 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f!::F, - y, - prep::EnzymeReverseTwoArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::EnzymeReverseTwoArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) (; df!, context_shadows, ty_copy) = prep copyto!(only(ty_copy), only(ty)) @@ -96,14 +96,14 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f!::F, - y, - prep::EnzymeReverseTwoArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f!::F, + y, + prep::EnzymeReverseTwoArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) (; df!, context_shadows, ty_copy) = prep foreach(copyto!, ty_copy, ty) @@ -119,15 +119,15 @@ function DI.value_and_pullback( end function DI.value_and_pullback!( - f!::F, - y, - tx::NTuple{1}, - prep::EnzymeReverseTwoArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + tx::NTuple{1}, + prep::EnzymeReverseTwoArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) (; df!, context_shadows, ty_copy) = prep copyto!(only(ty_copy), only(ty)) @@ -144,15 +144,15 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - f!::F, - y, - tx::NTuple{B}, - prep::EnzymeReverseTwoArgPullbackPrep, - backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, - x, - ty::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f!::F, + y, + tx::NTuple{B}, + prep::EnzymeReverseTwoArgPullbackPrep, + backend::AutoEnzyme{<:Union{ReverseMode, Nothing}}, + x, + ty::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) (; df!, context_shadows, ty_copy) = prep foreach(copyto!, ty_copy, ty) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 991796bb1..57c654cf5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -17,17 +17,17 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B) ## Annotations -function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,Nothing}, ::Val{B}) where {F,M,B} +function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M, Nothing}, ::Val{B}) where {F, M, B} return f end -function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,<:Const}, ::Val{B}) where {F,M,B} +function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M, <:Const}, ::Val{B}) where {F, M, B} return Const(f) end function get_f_and_df_prepared!( - df, f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B} -) where {F,M,B} + df, f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B} + ) where {F, M, B} #= It is not obvious why we don't need a `make_zero` here, in the case of mutable constant data in `f`. - In forward mode, `df` is never incremented if `f` is not mutated, so it remains equal to its initial value of `0`. @@ -41,12 +41,12 @@ function get_f_and_df_prepared!( end function function_shadow( - ::F, ::AutoEnzyme{M,<:Union{Const,Nothing}}, ::Val{B} -) where {M,B,F} + ::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Val{B} + ) where {M, B, F} return nothing end -function function_shadow(f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where {F,M,B} +function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B}) where {F, M, B} if B == 1 return make_zero(f) else @@ -54,7 +54,7 @@ function function_shadow(f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where end end -force_annotation(f::F) where {F<:Annotation} = f +force_annotation(f::F) where {F <: Annotation} = f force_annotation(f::F) where {F} = Const(f) function _shadow(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Constant) where {B} @@ -71,11 +71,11 @@ function _shadow(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Cache) where {B} end function _shadow( - ::AutoEnzyme, mode::Mode, valB::Val{B}, c_wrapped::DI.ConstantOrCache -) where {B} + ::AutoEnzyme, mode::Mode, valB::Val{B}, c_wrapped::DI.ConstantOrCache + ) where {B} c = DI.unwrap(c_wrapped) IA = guess_activity(typeof(c), mode) - if IA <: Const + return if IA <: Const nothing else if B == 1 @@ -87,18 +87,18 @@ function _shadow( end function _shadow( - backend::AutoEnzyme{M,<:Union{Const,Nothing}}, - ::Mode, - ::Val{B}, - c_wrapped::DI.FunctionContext, -) where {M,B} + backend::AutoEnzyme{M, <:Union{Const, Nothing}}, + ::Mode, + ::Val{B}, + c_wrapped::DI.FunctionContext, + ) where {M, B} f = DI.unwrap(c_wrapped) return function_shadow(f, backend, Val(B)) end function make_context_shadows( - backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context,C} -) where {B,C} + backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context, C} + ) where {B, C} context_shadows = map(contexts) do c_wrapped _shadow(backend, mode, Val(B), c_wrapped) end @@ -120,8 +120,8 @@ function _translate_prepared!(dc, c_wrapped::DI.Cache, ::Val{B}) where {B} end function _translate_prepared!( - dc, c_wrapped::Union{DI.ConstantOrCache,DI.FunctionContext}, ::Val{B} -) where {B} + dc, c_wrapped::Union{DI.ConstantOrCache, DI.FunctionContext}, ::Val{B} + ) where {B} #= It is not obvious why we don't need a `make_zero` here, in the case of mutable constant contexts. - In forward mode, `dc` is never incremented because `c` is not mutated, so it remains equal to its initial value of `0`. @@ -140,8 +140,8 @@ function _translate_prepared!( end function translate_prepared!( - context_shadows::NTuple{C,Any}, contexts::NTuple{C,DI.Context}, ::Val{B} -) where {B,C} + context_shadows::NTuple{C, Any}, contexts::NTuple{C, DI.Context}, ::Val{B} + ) where {B, C} new_contexts = map(context_shadows, contexts) do dc, c_wrapped _translate_prepared!(dc, c_wrapped, Val(B)) end @@ -170,10 +170,10 @@ function reverse_split_withprimal(backend::AutoEnzyme{Nothing}) return set_err(ReverseSplitWithPrimal, backend) end -function set_err(mode::Mode, backend::AutoEnzyme{<:Any,Nothing}) +function set_err(mode::Mode, backend::AutoEnzyme{<:Any, Nothing}) return EnzymeCore.set_err_if_func_written(mode) end -set_err(mode::Mode, backend::AutoEnzyme{<:Any,<:Annotation}) = mode +set_err(mode::Mode, backend::AutoEnzyme{<:Any, <:Annotation}) = mode function maybe_reshape(A::AbstractMatrix, m, n) @assert size(A) == (m, n) @@ -187,9 +187,9 @@ end annotate(::Type{Active{T}}, x, dx) where {T} = Active(x) annotate(::Type{Duplicated{T}}, x, dx) where {T} = Duplicated(x, dx) -function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B} +function annotate(::Type{BatchDuplicated{T, B}}, x, tx::NTuple{B}) where {T, B} return BatchDuplicated(x, tx) end -batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T} -batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B} +batchify_activity(::Type{Active{T}}, ::Val{B}) where {T, B} = Active{T} +batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T, B} = BatchDuplicated{T, B} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 430cc95f9..326c89fd9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -24,8 +24,8 @@ myvec(x::AbstractArray) = vec(x) variablize(::Number, name::Symbol) = only(make_variables(name)) variablize(x::AbstractArray, name::Symbol) = make_variables(name, size(x)...) -function variablize(contexts::NTuple{C,DI.Context}) where {C} - map(enumerate(contexts)) do (k, c) +function variablize(contexts::NTuple{C, DI.Context}) where {C} + return map(enumerate(contexts)) do (k, c) variablize(DI.unwrap(c), Symbol("context$k")) end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 21bea0c7b..316d1611c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -1,6 +1,6 @@ ## Pushforward -struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardPrep{SIG} +struct FastDifferentiationOneArgPushforwardPrep{SIG, Y, E1, E1!} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} y_prototype::Y jvp_exe::E1 @@ -8,13 +8,13 @@ struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardP end function DI.prepare_pushforward_nokwarg( - strict::Val, - f, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) @@ -25,22 +25,22 @@ function DI.prepare_pushforward_nokwarg( y_vec_var = myvec(y_var) jv_vec_var, v_vec_var = jacobian_times_v(y_vec_var, x_vec_var) jvp_exe = make_function( - jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = false ) jvp_exe! = make_function( - jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = true ) return FastDifferentiationOneArgPushforwardPrep(_sig, y_prototype, jvp_exe, jvp_exe!) end function DI.pushforward( - f, - prep::FastDifferentiationOneArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) ty = map(tx) do dx result = prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -54,14 +54,14 @@ function DI.pushforward( end function DI.pushforward!( - f, - ty::NTuple, - prep::FastDifferentiationOneArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + ty::NTuple, + prep::FastDifferentiationOneArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -71,48 +71,48 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f, - prep::FastDifferentiationOneArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pushforward(f, prep, backend, x, tx, contexts...) + DI.pushforward(f, prep, backend, x, tx, contexts...) end function DI.value_and_pushforward!( - f, - ty::NTuple, - prep::FastDifferentiationOneArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + ty::NTuple, + prep::FastDifferentiationOneArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) + DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Pullback -struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} +struct FastDifferentiationOneArgPullbackPrep{SIG, E1, E1!} <: DI.PullbackPrep{SIG} _sig::Val{SIG} vjp_exe::E1 vjp_exe!::E1! end function DI.prepare_pullback_nokwarg( - strict::Val, - f, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -123,22 +123,22 @@ function DI.prepare_pullback_nokwarg( y_vec_var = myvec(y_var) vj_vec_var, v_vec_var = jacobian_transpose_v(y_vec_var, x_vec_var) vjp_exe = make_function( - vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = false ) vjp_exe! = make_function( - vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = true ) return FastDifferentiationOneArgPullbackPrep(_sig, vjp_exe, vjp_exe!) end function DI.pullback( - f, - prep::FastDifferentiationOneArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) tx = map(ty) do dy result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) @@ -152,14 +152,14 @@ function DI.pullback( end function DI.pullback!( - f, - tx::NTuple, - prep::FastDifferentiationOneArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + tx::NTuple, + prep::FastDifferentiationOneArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -169,35 +169,35 @@ function DI.pullback!( end function DI.value_and_pullback( - f, - prep::FastDifferentiationOneArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pullback(f, prep, backend, x, ty, contexts...) + DI.pullback(f, prep, backend, x, ty, contexts...) end function DI.value_and_pullback!( - f, - tx::NTuple, - prep::FastDifferentiationOneArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + tx::NTuple, + prep::FastDifferentiationOneArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pullback!(f, tx, prep, backend, x, ty, contexts...) + DI.pullback!(f, tx, prep, backend, x, ty, contexts...) end ## Derivative -struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePrep{SIG} +struct FastDifferentiationOneArgDerivativePrep{SIG, Y, E1, E1!} <: DI.DerivativePrep{SIG} _sig::Val{SIG} y_prototype::Y der_exe::E1 @@ -205,8 +205,8 @@ struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePre end function DI.prepare_derivative_nokwarg( - strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) @@ -217,18 +217,18 @@ function DI.prepare_derivative_nokwarg( context_vec_vars = map(myvec, context_vars) y_vec_var = myvec(y_var) der_vec_var = derivative(y_vec_var, x_var) - der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) - der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) + der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place = false) + der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place = true) return FastDifferentiationOneArgDerivativePrep(_sig, y_prototype, der_exe, der_exe!) end function DI.derivative( - f, - prep::FastDifferentiationOneArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) result = prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...) if prep.y_prototype isa Number @@ -239,54 +239,54 @@ function DI.derivative( end function DI.derivative!( - f, - der, - prep::FastDifferentiationOneArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + prep::FastDifferentiationOneArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return der end function DI.value_and_derivative( - f, - prep::FastDifferentiationOneArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.derivative(f, prep, backend, x, contexts...) + DI.derivative(f, prep, backend, x, contexts...) end function DI.value_and_derivative!( - f, - der, - prep::FastDifferentiationOneArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + prep::FastDifferentiationOneArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.derivative!(f, der, prep, backend, x, contexts...) + DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient -struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} +struct FastDifferentiationOneArgGradientPrep{SIG, E1, E1!} <: DI.GradientPrep{SIG} _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! end function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -296,18 +296,18 @@ function DI.prepare_gradient_nokwarg( context_vec_vars = map(myvec, context_vars) y_vec_var = myvec(y_var) jac_var = jacobian(y_vec_var, x_vec_var) - jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) - jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) + jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place = false) + jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place = true) return FastDifferentiationOneArgGradientPrep(_sig, jac_exe, jac_exe!) end function DI.gradient( - f, - prep::FastDifferentiationOneArgGradientPrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgGradientPrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) grad_vec = @view jac[1, :] @@ -315,45 +315,45 @@ function DI.gradient( end function DI.gradient!( - f, - grad, - prep::FastDifferentiationOneArgGradientPrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::FastDifferentiationOneArgGradientPrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(reshape(grad, 1, length(grad)), myvec(x), map(myvec_unwrap, contexts)...) return grad end function DI.value_and_gradient( - f, - prep::FastDifferentiationOneArgGradientPrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgGradientPrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end function DI.value_and_gradient!( - f, - grad, - prep::FastDifferentiationOneArgGradientPrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::FastDifferentiationOneArgGradientPrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.gradient!(f, grad, prep, backend, x, contexts...) + DI.gradient!(f, grad, prep, backend, x, contexts...) end ## Jacobian -struct FastDifferentiationOneArgJacobianPrep{SIG,Y,P,E1,E1!} <: DI.SparseJacobianPrep{SIG} +struct FastDifferentiationOneArgJacobianPrep{SIG, Y, P, E1, E1!} <: DI.SparseJacobianPrep{SIG} _sig::Val{SIG} y_prototype::Y sparsity::P @@ -362,12 +362,12 @@ struct FastDifferentiationOneArgJacobianPrep{SIG,Y,P,E1,E1!} <: DI.SparseJacobia end function DI.prepare_jacobian_nokwarg( - strict::Val, - f, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) @@ -384,65 +384,65 @@ function DI.prepare_jacobian_nokwarg( jac_var = jacobian(y_vec_var, x_vec_var) sparsity = nothing end - jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) - jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) + jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place = false) + jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place = true) return FastDifferentiationOneArgJacobianPrep( _sig, y_prototype, sparsity, jac_exe, jac_exe! ) end function DI.jacobian( - f, - prep::FastDifferentiationOneArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) end function DI.jacobian!( - f, - jac, - prep::FastDifferentiationOneArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::FastDifferentiationOneArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return jac end function DI.value_and_jacobian( - f, - prep::FastDifferentiationOneArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationOneArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end function DI.value_and_jacobian!( - f, - jac, - prep::FastDifferentiationOneArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::FastDifferentiationOneArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.jacobian!(f, jac, prep, backend, x, contexts...) + DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Second derivative -struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: - DI.SecondDerivativePrep{SIG} +struct FastDifferentiationAllocatingSecondDerivativePrep{SIG, Y, D, E2, E2!} <: + DI.SecondDerivativePrep{SIG} _sig::Val{SIG} y_prototype::Y derivative_prep::D @@ -451,8 +451,8 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: end function DI.prepare_second_derivative_nokwarg( - strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) @@ -464,8 +464,8 @@ function DI.prepare_second_derivative_nokwarg( y_vec_var = myvec(y_var) der2_vec_var = derivative(y_vec_var, x_var, x_var) - der2_exe = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=false) - der2_exe! = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=true) + der2_exe = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place = false) + der2_exe! = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place = true) derivative_prep = DI.prepare_derivative_nokwarg(strict, f, backend, x, contexts...) return FastDifferentiationAllocatingSecondDerivativePrep( @@ -474,12 +474,12 @@ function DI.prepare_second_derivative_nokwarg( end function DI.second_derivative( - f, - prep::FastDifferentiationAllocatingSecondDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationAllocatingSecondDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) result = prep.der2_exe(myvec(x), map(myvec_unwrap, contexts)...) if prep.y_prototype isa Number @@ -490,25 +490,25 @@ function DI.second_derivative( end function DI.second_derivative!( - f, - der2, - prep::FastDifferentiationAllocatingSecondDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der2, + prep::FastDifferentiationAllocatingSecondDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.der2_exe!(myvec(der2), myvec(x), map(myvec_unwrap, contexts)...) return der2 end function DI.value_derivative_and_second_derivative( - f, - prep::FastDifferentiationAllocatingSecondDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationAllocatingSecondDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) der2 = DI.second_derivative(f, prep, backend, x, contexts...) @@ -516,14 +516,14 @@ function DI.value_derivative_and_second_derivative( end function DI.value_derivative_and_second_derivative!( - f, - der, - der2, - prep::FastDifferentiationAllocatingSecondDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + der2, + prep::FastDifferentiationAllocatingSecondDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) DI.second_derivative!(f, der2, prep, backend, x, contexts...) @@ -532,7 +532,7 @@ end ## HVP -struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} +struct FastDifferentiationHVPPrep{SIG, E2, E2!, E1} <: DI.HVPPrep{SIG} sig::Val{SIG} hvp_exe::E2 hvp_exe!::E2! @@ -540,13 +540,13 @@ struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} end function DI.prepare_hvp_nokwarg( - strict::Val, - f, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -556,10 +556,10 @@ function DI.prepare_hvp_nokwarg( context_vec_vars = map(myvec, context_vars) hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var) hvp_exe = make_function( - hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = false ) hvp_exe! = make_function( - hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = true ) gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) @@ -567,13 +567,13 @@ function DI.prepare_hvp_nokwarg( end function DI.hvp( - f, - prep::FastDifferentiationHVPPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationHVPPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) tg = map(tx) do dx dg_vec = prep.hvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -583,14 +583,14 @@ function DI.hvp( end function DI.hvp!( - f, - tg::NTuple, - prep::FastDifferentiationHVPPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + tg::NTuple, + prep::FastDifferentiationHVPPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] @@ -600,13 +600,13 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, - prep::FastDifferentiationHVPPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationHVPPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) tg = DI.hvp(f, prep, backend, x, tx, contexts...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) @@ -614,15 +614,15 @@ function DI.gradient_and_hvp( end function DI.gradient_and_hvp!( - f, - grad, - tg::NTuple, - prep::FastDifferentiationHVPPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + tg::NTuple, + prep::FastDifferentiationHVPPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hvp!(f, tg, prep, backend, x, tx, contexts...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) @@ -631,7 +631,7 @@ end ## Hessian -struct FastDifferentiationHessianPrep{SIG,G,P,E2,E2!} <: DI.SparseHessianPrep{SIG} +struct FastDifferentiationHessianPrep{SIG, G, P, E2, E2!} <: DI.SparseHessianPrep{SIG} _sig::Val{SIG} gradient_prep::G sparsity::P @@ -640,12 +640,12 @@ struct FastDifferentiationHessianPrep{SIG,G,P,E2,E2!} <: DI.SparseHessianPrep{SI end function DI.prepare_hessian_nokwarg( - strict::Val, - f, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -661,8 +661,8 @@ function DI.prepare_hessian_nokwarg( hess_var = hessian(y_var, x_vec_var) sparsity = nothing end - hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false) - hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true) + hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place = false) + hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place = true) gradient_prep = DI.prepare_gradient_nokwarg( strict, f, dense_ad(backend), x, contexts... @@ -673,36 +673,36 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian( - f, - prep::FastDifferentiationHessianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationHessianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return prep.hess_exe(myvec(x), map(myvec_unwrap, contexts)...) end function DI.hessian!( - f, - hess, - prep::FastDifferentiationHessianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + hess, + prep::FastDifferentiationHessianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.hess_exe!(hess, myvec(x), map(myvec_unwrap, contexts)...) return hess end function DI.value_gradient_and_hessian( - f, - prep::FastDifferentiationHessianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FastDifferentiationHessianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, dense_ad(backend), x, contexts... @@ -712,14 +712,14 @@ function DI.value_gradient_and_hessian( end function DI.value_gradient_and_hessian!( - f, - grad, - hess, - prep::FastDifferentiationHessianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + hess, + prep::FastDifferentiationHessianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index 8a899a27a..67fb00149 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -1,20 +1,20 @@ ## Pushforward -struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} +struct FastDifferentiationTwoArgPushforwardPrep{SIG, E1, E1!} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} jvp_exe::E1 jvp_exe!::E1! end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!, - y, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -26,23 +26,23 @@ function DI.prepare_pushforward_nokwarg( y_vec_var = myvec(y_var) jv_vec_var, v_vec_var = jacobian_times_v(y_vec_var, x_vec_var) jvp_exe = make_function( - jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = false ) jvp_exe! = make_function( - jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = true ) return FastDifferentiationTwoArgPushforwardPrep(_sig, jvp_exe, jvp_exe!) end function DI.pushforward( - f!, - y, - prep::FastDifferentiationTwoArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx reshape(prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...), size(y)) @@ -51,15 +51,15 @@ function DI.pushforward( end function DI.pushforward!( - f!, - y, - ty::NTuple, - prep::FastDifferentiationTwoArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::FastDifferentiationTwoArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -69,14 +69,14 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f!, - y, - prep::FastDifferentiationTwoArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -84,15 +84,15 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f!, - y, - ty::NTuple, - prep::FastDifferentiationTwoArgPushforwardPrep, - backend::AutoFastDifferentiation, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::FastDifferentiationTwoArgPushforwardPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -101,21 +101,21 @@ end ## Pullback -struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} +struct FastDifferentiationTwoArgPullbackPrep{SIG, E1, E1!} <: DI.PullbackPrep{SIG} _sig::Val{SIG} vjp_exe::E1 vjp_exe!::E1! end function DI.prepare_pullback_nokwarg( - strict::Val, - f!, - y, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -127,23 +127,23 @@ function DI.prepare_pullback_nokwarg( y_vec_var = myvec(y_var) vj_vec_var, v_vec_var = jacobian_transpose_v(y_vec_var, x_vec_var) vjp_exe = make_function( - vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = false ) vjp_exe! = make_function( - vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place = true ) return FastDifferentiationTwoArgPullbackPrep(_sig, vjp_exe, vjp_exe!) end function DI.pullback( - f!, - y, - prep::FastDifferentiationTwoArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) @@ -157,15 +157,15 @@ function DI.pullback( end function DI.pullback!( - f!, - y, - tx::NTuple, - prep::FastDifferentiationTwoArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + tx::NTuple, + prep::FastDifferentiationTwoArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -175,14 +175,14 @@ function DI.pullback!( end function DI.value_and_pullback( - f!, - y, - prep::FastDifferentiationTwoArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = DI.pullback(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -190,15 +190,15 @@ function DI.value_and_pullback( end function DI.value_and_pullback!( - f!, - y, - tx::NTuple, - prep::FastDifferentiationTwoArgPullbackPrep, - backend::AutoFastDifferentiation, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + tx::NTuple, + prep::FastDifferentiationTwoArgPullbackPrep, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) DI.pullback!(f!, y, tx, prep, backend, x, ty, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -207,15 +207,15 @@ end ## Derivative -struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} +struct FastDifferentiationTwoArgDerivativePrep{SIG, E1, E1!} <: DI.DerivativePrep{SIG} _sig::Val{SIG} der_exe::E1 der_exe!::E1! end function DI.prepare_derivative_nokwarg( - strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -226,19 +226,19 @@ function DI.prepare_derivative_nokwarg( context_vec_vars = map(myvec, context_vars) y_vec_var = myvec(y_var) der_vec_var = derivative(y_vec_var, x_var) - der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) - der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) + der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place = false) + der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place = true) return FastDifferentiationTwoArgDerivativePrep(_sig, der_exe, der_exe!) end function DI.value_and_derivative( - f!, - y, - prep::FastDifferentiationTwoArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) @@ -246,14 +246,14 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!, - y, - der, - prep::FastDifferentiationTwoArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::FastDifferentiationTwoArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) @@ -261,27 +261,27 @@ function DI.value_and_derivative!( end function DI.derivative( - f!, - y, - prep::FastDifferentiationTwoArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) return der end function DI.derivative!( - f!, - y, - der, - prep::FastDifferentiationTwoArgDerivativePrep, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::FastDifferentiationTwoArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return der @@ -289,7 +289,7 @@ end ## Jacobian -struct FastDifferentiationTwoArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG} +struct FastDifferentiationTwoArgJacobianPrep{SIG, P, E1, E1!} <: DI.SparseJacobianPrep{SIG} _sig::Val{SIG} sparsity::P jac_exe::E1 @@ -297,13 +297,13 @@ struct FastDifferentiationTwoArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianP end function DI.prepare_jacobian_nokwarg( - strict::Val, - f!, - y, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -320,19 +320,19 @@ function DI.prepare_jacobian_nokwarg( jac_var = jacobian(y_vec_var, x_vec_var) sparsity = nothing end - jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) - jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) + jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place = false) + jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place = true) return FastDifferentiationTwoArgJacobianPrep(_sig, sparsity, jac_exe, jac_exe!) end function DI.value_and_jacobian( - f!, - y, - prep::FastDifferentiationTwoArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) @@ -340,14 +340,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, - y, - jac, - prep::FastDifferentiationTwoArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::FastDifferentiationTwoArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) @@ -355,27 +355,27 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, - y, - prep::FastDifferentiationTwoArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FastDifferentiationTwoArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) return jac end function DI.jacobian!( - f!, - y, - jac, - prep::FastDifferentiationTwoArgJacobianPrep, - backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::FastDifferentiationTwoArgJacobianPrep, + backend::Union{AutoFastDifferentiation, AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl index 1822b7fe2..a005749fa 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl @@ -26,8 +26,8 @@ DI.inner_preparation_behavior(::AutoFiniteDiff) = DI.PrepareInnerSimple() # see https://github.com/SciML/ADTypes.jl/issues/33 fdtype(::AutoFiniteDiff{fdt}) where {fdt} = fdt -fdjtype(::AutoFiniteDiff{fdt,fdjt}) where {fdt,fdjt} = fdjt -fdhtype(::AutoFiniteDiff{fdt,fdjt,fdht}) where {fdt,fdjt,fdht} = fdht +fdjtype(::AutoFiniteDiff{fdt, fdjt}) where {fdt, fdjt} = fdjt +fdhtype(::AutoFiniteDiff{fdt, fdjt, fdht}) where {fdt, fdjt, fdht} = fdht # see https://docs.sciml.ai/FiniteDiff/stable/#f-Definitions const FUNCTION_INPLACE = Val{true} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index b16edaa05..82b769ef6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -1,6 +1,6 @@ ## Pushforward -struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} +struct FiniteDiffOneArgPushforwardPrep{SIG, C, R, A, D} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::C relstep::R @@ -9,8 +9,8 @@ struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward_nokwarg( - strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -34,13 +34,13 @@ function DI.prepare_pushforward_nokwarg( end function DI.pushforward( - f, - prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgPushforwardPrep{SIG, Nothing}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) @@ -53,13 +53,13 @@ function DI.pushforward( end function DI.value_and_pushforward( - f, - prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgPushforwardPrep{SIG, Nothing}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) @@ -80,13 +80,13 @@ function DI.value_and_pushforward( end function DI.pushforward( - f, - prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgPushforwardPrep{SIG, <:JVPCache}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -97,13 +97,13 @@ function DI.pushforward( end function DI.value_and_pushforward( - f, - prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgPushforwardPrep{SIG, <:JVPCache}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -116,7 +116,7 @@ end ## Derivative -struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} +struct FiniteDiffOneArgDerivativePrep{SIG, C, R, A, D} <: DI.DerivativePrep{SIG} _sig::Val{SIG} cache::C relstep::R @@ -125,8 +125,8 @@ struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative_nokwarg( - strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -153,12 +153,12 @@ end ### Scalar to scalar function DI.derivative( - f, - prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing}, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgDerivativePrep{SIG, Nothing}, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -166,12 +166,12 @@ function DI.derivative( end function DI.value_and_derivative( - f, - prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing}, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgDerivativePrep{SIG, Nothing}, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -187,12 +187,12 @@ end ### Scalar to array function DI.derivative( - f, - prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgDerivativePrep{SIG, <:GradientCache}, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -200,13 +200,13 @@ function DI.derivative( end function DI.derivative!( - f, - der, - prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + der, + prep::FiniteDiffOneArgDerivativePrep{SIG, <:GradientCache}, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -214,12 +214,12 @@ function DI.derivative!( end function DI.value_and_derivative( - f, - prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + prep::FiniteDiffOneArgDerivativePrep{SIG, <:GradientCache}, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) (; relstep, absstep, dir) = prep @@ -228,24 +228,24 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f, - der, - prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f, + der, + prep::FiniteDiffOneArgDerivativePrep{SIG, <:GradientCache}, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return ( - fc(x), finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir) + fc(x), finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir), ) end ## Gradient -struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} +struct FiniteDiffGradientPrep{SIG, C, R, A, D} <: DI.GradientPrep{SIG} _sig::Val{SIG} cache::C relstep::R @@ -254,8 +254,8 @@ struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} end function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -276,12 +276,12 @@ function DI.prepare_gradient_nokwarg( end function DI.gradient( - f, - prep::FiniteDiffGradientPrep, - backend::AutoFiniteDiff, - x::AbstractArray, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FiniteDiffGradientPrep, + backend::AutoFiniteDiff, + x::AbstractArray, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -289,12 +289,12 @@ function DI.gradient( end function DI.value_and_gradient( - f, - prep::FiniteDiffGradientPrep, - backend::AutoFiniteDiff, - x::AbstractArray, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FiniteDiffGradientPrep, + backend::AutoFiniteDiff, + x::AbstractArray, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -302,13 +302,13 @@ function DI.value_and_gradient( end function DI.gradient!( - f, - grad, - prep::FiniteDiffGradientPrep, - backend::AutoFiniteDiff, - x::AbstractArray, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::FiniteDiffGradientPrep, + backend::AutoFiniteDiff, + x::AbstractArray, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -316,24 +316,24 @@ function DI.gradient!( end function DI.value_and_gradient!( - f, - grad, - prep::FiniteDiffGradientPrep, - backend::AutoFiniteDiff, - x::AbstractArray, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::FiniteDiffGradientPrep, + backend::AutoFiniteDiff, + x::AbstractArray, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return ( - fc(x), finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir) + fc(x), finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir), ) end ## Jacobian -struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} +struct FiniteDiffOneArgJacobianPrep{SIG, C, R, A, D} <: DI.JacobianPrep{SIG} _sig::Val{SIG} cache::C relstep::R @@ -342,8 +342,8 @@ struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian_nokwarg( - strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -366,12 +366,12 @@ function DI.prepare_jacobian_nokwarg( end function DI.jacobian( - f, - prep::FiniteDiffOneArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FiniteDiffOneArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -379,12 +379,12 @@ function DI.jacobian( end function DI.value_and_jacobian( - f, - prep::FiniteDiffOneArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FiniteDiffOneArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) (; relstep, absstep, dir) = prep @@ -393,32 +393,32 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f, - jac, - prep::FiniteDiffOneArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::FiniteDiffOneArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return copyto!( jac, finite_difference_jacobian( - fc, x, prep.cache; jac_prototype=jac, relstep, absstep, dir + fc, x, prep.cache; jac_prototype = jac, relstep, absstep, dir ), ) end function DI.value_and_jacobian!( - f, - jac, - prep::FiniteDiffOneArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::FiniteDiffOneArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -428,7 +428,7 @@ function DI.value_and_jacobian!( copyto!( jac, finite_difference_jacobian( - fc, x, prep.cache, y; jac_prototype=jac, relstep, absstep, dir + fc, x, prep.cache, y; jac_prototype = jac, relstep, absstep, dir ), ), ) @@ -436,7 +436,7 @@ end ## Hessian -struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} +struct FiniteDiffHessianPrep{SIG, C1, C2, RG, AG, RH, AH} <: DI.HessianPrep{SIG} _sig::Val{SIG} gradient_cache::C1 hessian_cache::C2 @@ -447,8 +447,8 @@ struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} end function DI.prepare_hessian_nokwarg( - strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -481,72 +481,72 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian( - f, - prep::FiniteDiffHessianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_h, absstep_h) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_hessian( - fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h + fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h ) end function DI.hessian!( - f, - hess, - prep::FiniteDiffHessianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + hess, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_h, absstep_h) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_hessian!( - hess, fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h + hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h ) end function DI.value_gradient_and_hessian( - f, - prep::FiniteDiffHessianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) grad = finite_difference_gradient( - fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g + fc, x, prep.gradient_cache; relstep = relstep_g, absstep = absstep_g ) hess = finite_difference_hessian( - fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h + fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h ) return fc(x), grad, hess end function DI.value_gradient_and_hessian!( - f, - grad, - hess, - prep::FiniteDiffHessianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + hess, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) finite_difference_gradient!( - grad, fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g + grad, fc, x, prep.gradient_cache; relstep = relstep_g, absstep = absstep_g ) finite_difference_hessian!( - hess, fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h + hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h ) return fc(x), grad, hess end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index f8b70c6dd..e30ba1ec7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -1,6 +1,6 @@ ## Pushforward -struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} +struct FiniteDiffTwoArgPushforwardPrep{SIG, C, R, A, D} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::C relstep::R @@ -9,14 +9,14 @@ struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!, - y, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) cache = if x isa Number nothing @@ -38,14 +38,14 @@ function DI.prepare_pushforward_nokwarg( end function DI.value_and_pushforward( - f!, - y, - prep::FiniteDiffTwoArgPushforwardPrep{SIG,Nothing}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f!, + y, + prep::FiniteDiffTwoArgPushforwardPrep{SIG, Nothing}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep function step(t::Number, dx) @@ -70,14 +70,14 @@ function DI.value_and_pushforward( end function DI.pushforward( - f!, - y, - prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f!, + y, + prep::FiniteDiffTwoArgPushforwardPrep{SIG, <:JVPCache}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -90,14 +90,14 @@ function DI.pushforward( end function DI.value_and_pushforward( - f!, - y, - prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f!, + y, + prep::FiniteDiffTwoArgPushforwardPrep{SIG, <:JVPCache}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -111,15 +111,15 @@ function DI.value_and_pushforward( end function DI.pushforward!( - f!, - y, - ty::NTuple, - prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f!, + y, + ty::NTuple, + prep::FiniteDiffTwoArgPushforwardPrep{SIG, <:JVPCache}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -131,15 +131,15 @@ function DI.pushforward!( end function DI.value_and_pushforward!( - f!, - y, - ty::NTuple, - prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {SIG,C} + f!, + y, + ty::NTuple, + prep::FiniteDiffTwoArgPushforwardPrep{SIG, <:JVPCache}, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {SIG, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -153,7 +153,7 @@ end ## Derivative -struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} +struct FiniteDiffTwoArgDerivativePrep{SIG, C, R, A, D} <: DI.DerivativePrep{SIG} _sig::Val{SIG} cache::C relstep::R @@ -162,8 +162,8 @@ struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative_nokwarg( - strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) df = similar(y) cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_INPLACE) @@ -182,20 +182,20 @@ function DI.prepare_derivative_nokwarg( end function DI.prepare!_derivative( - f!, - y, - old_prep::FiniteDiffTwoArgDerivativePrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + old_prep::FiniteDiffTwoArgDerivativePrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, old_prep, backend, x, contexts...) if y isa Vector (; cache) = old_prep - cache.fx isa Union{Number,Nothing} || resize!(cache.fx, length(y)) - cache.c1 isa Union{Number,Nothing} || resize!(cache.c1, length(y)) - cache.c2 isa Union{Number,Nothing} || resize!(cache.c2, length(y)) - cache.c3 isa Union{Number,Nothing} || resize!(cache.c3, length(y)) + cache.fx isa Union{Number, Nothing} || resize!(cache.fx, length(y)) + cache.c1 isa Union{Number, Nothing} || resize!(cache.c1, length(y)) + cache.c2 isa Union{Number, Nothing} || resize!(cache.c2, length(y)) + cache.c3 isa Union{Number, Nothing} || resize!(cache.c3, length(y)) return old_prep else return DI.prepare_derivative_nokwarg( @@ -205,13 +205,13 @@ function DI.prepare!_derivative( end function DI.value_and_derivative( - f!, - y, - prep::FiniteDiffTwoArgDerivativePrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FiniteDiffTwoArgDerivativePrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -221,14 +221,14 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!, - y, - der, - prep::FiniteDiffTwoArgDerivativePrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::FiniteDiffTwoArgDerivativePrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -238,13 +238,13 @@ function DI.value_and_derivative!( end function DI.derivative( - f!, - y, - prep::FiniteDiffTwoArgDerivativePrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FiniteDiffTwoArgDerivativePrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -254,14 +254,14 @@ function DI.derivative( end function DI.derivative!( - f!, - y, - der, - prep::FiniteDiffTwoArgDerivativePrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::FiniteDiffTwoArgDerivativePrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -271,7 +271,7 @@ end ## Jacobian -struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} +struct FiniteDiffTwoArgJacobianPrep{SIG, C, R, A, D} <: DI.JacobianPrep{SIG} _sig::Val{SIG} cache::C relstep::R @@ -280,8 +280,8 @@ struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian_nokwarg( - strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x1 = similar(x) fx = similar(y) @@ -302,20 +302,20 @@ function DI.prepare_jacobian_nokwarg( end function DI.prepare!_jacobian( - f!, - y, - old_prep::FiniteDiffTwoArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + old_prep::FiniteDiffTwoArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, old_prep, backend, x, contexts...) if x isa Vector && y isa Vector (; cache) = old_prep - cache.x1 isa Union{Number,Nothing} || resize!(cache.x1, length(x)) - cache.x2 isa Union{Number,Nothing} || resize!(cache.x2, length(x)) - cache.fx isa Union{Number,Nothing} || resize!(cache.fx, length(y)) - cache.fx1 isa Union{Number,Nothing} || resize!(cache.fx1, length(y)) + cache.x1 isa Union{Number, Nothing} || resize!(cache.x1, length(x)) + cache.x2 isa Union{Number, Nothing} || resize!(cache.x2, length(x)) + cache.fx isa Union{Number, Nothing} || resize!(cache.fx, length(y)) + cache.fx1 isa Union{Number, Nothing} || resize!(cache.fx1, length(y)) cache.colorvec = 1:length(x) cache.sparsity = nothing return old_prep @@ -327,13 +327,13 @@ function DI.prepare!_jacobian( end function DI.value_and_jacobian( - f!, - y, - prep::FiniteDiffTwoArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FiniteDiffTwoArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -344,14 +344,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, - y, - jac, - prep::FiniteDiffTwoArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::FiniteDiffTwoArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -361,13 +361,13 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, - y, - prep::FiniteDiffTwoArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::FiniteDiffTwoArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) @@ -377,14 +377,14 @@ function DI.jacobian( end function DI.jacobian!( - f!, - y, - jac, - prep::FiniteDiffTwoArgJacobianPrep, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::FiniteDiffTwoArgJacobianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 366c6ba31..3d3c26f4a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -12,25 +12,25 @@ DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple() ## Pushforward function DI.prepare_pushforward_nokwarg( - strict::Val, - f, - backend::AutoFiniteDifferences, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::AutoFiniteDifferences, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) return DI.NoPushforwardPrep(_sig) end function DI.pushforward( - f, - prep::DI.NoPushforwardPrep, - backend::AutoFiniteDifferences, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoPushforwardPrep, + backend::AutoFiniteDifferences, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) ty = map(tx) do dx @@ -40,40 +40,40 @@ function DI.pushforward( end function DI.value_and_pushforward( - f, - prep::DI.NoPushforwardPrep, - backend::AutoFiniteDifferences, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoPushforwardPrep, + backend::AutoFiniteDifferences, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pushforward(f, prep, backend, x, tx, contexts...) + DI.pushforward(f, prep, backend, x, tx, contexts...) end ## Pullback function DI.prepare_pullback_nokwarg( - strict::Val, - f, - backend::AutoFiniteDifferences, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::AutoFiniteDifferences, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) end function DI.pullback( - f, - prep::DI.NoPullbackPrep, - backend::AutoFiniteDifferences, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoFiniteDifferences, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -83,70 +83,70 @@ function DI.pullback( end function DI.value_and_pullback( - f, - prep::DI.NoPullbackPrep, - backend::AutoFiniteDifferences, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoFiniteDifferences, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pullback(f, prep, backend, x, ty, contexts...) + DI.pullback(f, prep, backend, x, ty, contexts...) end ## Gradient function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoGradientPrep(_sig) end function DI.gradient( - f, - prep::DI.NoGradientPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoGradientPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return only(grad(backend.fdm, fc, x)) end function DI.value_and_gradient( - f, - prep::DI.NoGradientPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoGradientPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end function DI.gradient!( - f, - grad, - prep::DI.NoGradientPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::DI.NoGradientPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end function DI.value_and_gradient!( - f, - grad, - prep::DI.NoGradientPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::DI.NoGradientPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -155,55 +155,55 @@ end ## Jacobian function DI.prepare_jacobian_nokwarg( - strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoJacobianPrep(_sig) end function DI.jacobian( - f, - prep::DI.NoJacobianPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoJacobianPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return only(jacobian(backend.fdm, fc, x)) end function DI.value_and_jacobian( - f, - prep::DI.NoJacobianPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoJacobianPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end function DI.jacobian!( - f, - jac, - prep::DI.NoJacobianPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::DI.NoJacobianPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end function DI.value_and_jacobian!( - f, - jac, - prep::DI.NoJacobianPrep, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::DI.NoJacobianPrep, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl index 55297886f..96316f5b6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl @@ -1,4 +1,4 @@ -function (dw::DI.DifferentiateWith)(x::Dual{T,V,N}) where {T,V,N} +function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N} (; f, backend) = dw xval = myvalue(T, x) tx = mypartials(T, Val(N), x) @@ -6,7 +6,7 @@ function (dw::DI.DifferentiateWith)(x::Dual{T,V,N}) where {T,V,N} return make_dual(T, y, ty) end -function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T,V,N}}) where {T,V,N} +function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N} (; f, backend) = dw xval = myvalue(T, x) tx = mypartials(T, Val(N), x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index 2d5e4c9cc..6cf329ce9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -3,16 +3,16 @@ DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.x DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) function DI.overloaded_input( - ::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B} -) where {F,B} + ::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B} + ) where {F, B} T = tag_type(f, backend, x) xdual = make_dual(T, x, tx) return xdual end function DI.overloaded_input( - ::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B} -) where {F,B} + ::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B} + ) where {F, B} T = tag_type(f!, backend, x) xdual = make_dual(T, x, tx) return xdual diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 631e27076..a0360e68b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -3,8 +3,8 @@ ### Unprepared (avoid working on `similar(x)`) function DI.value_and_pushforward( - f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C} -) where {F,B,C} + f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context, C} + ) where {F, B, C} T = tag_type(f, backend, x) xdual = make_dual(T, x, tx) contexts_dual = translate(eltype(xdual), contexts) @@ -15,13 +15,13 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f::F, - ty::NTuple{B}, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f::F, + ty::NTuple{B}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} T = tag_type(f, backend, x) xdual = make_dual(T, x, tx) contexts_dual = translate(eltype(xdual), contexts) @@ -32,8 +32,8 @@ function DI.value_and_pushforward!( end function DI.pushforward( - f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C} -) where {F,B,C} + f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context, C} + ) where {F, B, C} T = tag_type(f, backend, x) xdual = make_dual(T, x, tx) contexts_dual = translate(eltype(xdual), contexts) @@ -43,13 +43,13 @@ function DI.pushforward( end function DI.pushforward!( - f::F, - ty::NTuple{B}, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + f::F, + ty::NTuple{B}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, B, C} T = tag_type(f, backend, x) xdual = make_dual(T, x, tx) contexts_dual = translate(eltype(xdual), contexts) @@ -60,7 +60,7 @@ end ### Prepared -struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} +struct ForwardDiffOneArgPushforwardPrep{SIG, T, X, CD} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} _t::Type{T} xdual_tmp::X @@ -68,13 +68,13 @@ struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward_nokwarg( - strict::Val, - f::F, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}; -) where {F,B,C} + strict::Val, + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C} + ) where {F, B, C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) T = tag_type(f, backend, x) if DI.ismutable_array(x) @@ -82,17 +82,17 @@ function DI.prepare_pushforward_nokwarg( else xdual_tmp = nothing end - contexts_dual = translate_toprep(Dual{T,eltype(x),B}, contexts) + contexts_dual = translate_toprep(Dual{T, eltype(x), B}, contexts) return ForwardDiffOneArgPushforwardPrep(_sig, T, xdual_tmp, contexts_dual) end function compute_ydual_onearg( - f::F, - prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, - x::Number, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f::F, + prep::ForwardDiffOneArgPushforwardPrep{SIG, T}, + x::Number, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} xdual = make_dual(T, x, tx) contexts_dual = translate_prepared(contexts, prep.contexts_dual) ydual = f(xdual, contexts_dual...) @@ -100,12 +100,12 @@ function compute_ydual_onearg( end function compute_ydual_onearg( - f::F, - prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f::F, + prep::ForwardDiffOneArgPushforwardPrep{SIG, T}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} if DI.ismutable_array(x) make_dual!(T, prep.xdual_tmp, x, tx) xdual_tmp = prep.xdual_tmp @@ -118,13 +118,13 @@ function compute_ydual_onearg( end function DI.value_and_pushforward( - f::F, - prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f::F, + prep::ForwardDiffOneArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) y = myvalue(T, ydual) @@ -133,14 +133,14 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f::F, - ty::NTuple, - prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,C} + f::F, + ty::NTuple, + prep::ForwardDiffOneArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, C} DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) y = myvalue(T, ydual) @@ -149,13 +149,13 @@ function DI.value_and_pushforward!( end function DI.pushforward( - f::F, - prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f::F, + prep::ForwardDiffOneArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual) @@ -163,14 +163,14 @@ function DI.pushforward( end function DI.pushforward!( - f::F, - ty::NTuple, - prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,C} + f::F, + ty::NTuple, + prep::ForwardDiffOneArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, C} DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) mypartials!(T, ty, ydual) @@ -179,7 +179,7 @@ end ## Derivative -struct ForwardDiffOneArgDerivativePrep{SIG,E} <: DI.DerivativePrep{SIG} +struct ForwardDiffOneArgDerivativePrep{SIG, E} <: DI.DerivativePrep{SIG} _sig::Val{SIG} pushforward_prep::E end @@ -187,28 +187,28 @@ end ### Unprepared function DI.value_and_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {F,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} y, ty = DI.value_and_pushforward(f, backend, x, (oneunit(x),), contexts...) return y, only(ty) end function DI.value_and_derivative!( - f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {F,C} + f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...) return y, der end function DI.derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {F,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} return only(DI.pushforward(f, backend, x, (oneunit(x),), contexts...)) end function DI.derivative!( - f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {F,C} + f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} DI.pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...) return der end @@ -216,8 +216,8 @@ end ### Prepared function DI.prepare_derivative_nokwarg( - strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) pushforward_prep = DI.prepare_pushforward_nokwarg( strict, f, backend, x, (oneunit(x),), contexts... @@ -226,12 +226,12 @@ function DI.prepare_derivative_nokwarg( end function DI.value_and_derivative( - f::F, - prep::ForwardDiffOneArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffOneArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -240,13 +240,13 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f::F, - der, - prep::ForwardDiffOneArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + der, + prep::ForwardDiffOneArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -255,12 +255,12 @@ function DI.value_and_derivative!( end function DI.derivative( - f::F, - prep::ForwardDiffOneArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffOneArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) @@ -268,13 +268,13 @@ function DI.derivative( end function DI.derivative!( - f::F, - der, - prep::ForwardDiffOneArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + der, + prep::ForwardDiffOneArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) DI.pushforward!( f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -287,13 +287,13 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, grad, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = DiffResult(zero(eltype(x)), (grad,)) result = gradient!(result, fc, x) @@ -307,13 +307,13 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = GradientResult(x) result = gradient!(result, fc, x) @@ -325,13 +325,13 @@ function DI.value_and_gradient( end function DI.gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, grad, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient!(grad, fc, x) else @@ -341,13 +341,13 @@ function DI.gradient!( end function DI.gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient(fc, x) else @@ -358,19 +358,19 @@ end ### Prepared -struct ForwardDiffGradientPrep{SIG,C,CD} <: DI.GradientPrep{SIG} +struct ForwardDiffGradientPrep{SIG, C, CD} <: DI.GradientPrep{SIG} _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_gradient_nokwarg( - strict::Val, - f::F, - backend::AutoForwardDiff, - x::AbstractArray, - contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoForwardDiff, + x::AbstractArray, + contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) @@ -380,13 +380,13 @@ function DI.prepare_gradient_nokwarg( end function DI.value_and_gradient!( - f::F, - grad, - prep::ForwardDiffGradientPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + prep::ForwardDiffGradientPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -402,12 +402,12 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f::F, - prep::ForwardDiffGradientPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffGradientPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -421,13 +421,13 @@ function DI.value_and_gradient( end function DI.gradient!( - f::F, - grad, - prep::ForwardDiffGradientPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + prep::ForwardDiffGradientPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -439,12 +439,12 @@ function DI.gradient!( end function DI.gradient( - f::F, - prep::ForwardDiffGradientPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffGradientPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -460,13 +460,13 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, jac, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) result = DiffResult(y, (jac,)) @@ -481,13 +481,13 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return fc(x), jacobian(fc, x) else @@ -497,13 +497,13 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, jac, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian!(jac, fc, x) else @@ -513,13 +513,13 @@ function DI.jacobian!( end function DI.jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian(fc, x) else @@ -530,15 +530,15 @@ end ### Prepared -struct ForwardDiffOneArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} +struct ForwardDiffOneArgJacobianPrep{SIG, C, CD} <: DI.JacobianPrep{SIG} _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_jacobian_nokwarg( - strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) @@ -548,13 +548,13 @@ function DI.prepare_jacobian_nokwarg( end function DI.value_and_jacobian!( - f::F, - jac, - prep::ForwardDiffOneArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + jac, + prep::ForwardDiffOneArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -571,12 +571,12 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f::F, - prep::ForwardDiffOneArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffOneArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -588,13 +588,13 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f::F, - jac, - prep::ForwardDiffOneArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + jac, + prep::ForwardDiffOneArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -606,12 +606,12 @@ function DI.jacobian!( end function DI.jacobian( - f::F, - prep::ForwardDiffOneArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffOneArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -625,19 +625,19 @@ end ## Second derivative function DI.prepare_second_derivative_nokwarg( - strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoSecondDerivativePrep(_sig) end function DI.second_derivative( - f::F, - prep::DI.NoSecondDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::DI.NoSecondDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, oneunit(x)) @@ -649,13 +649,13 @@ function DI.second_derivative( end function DI.second_derivative!( - f::F, - der2, - prep::DI.NoSecondDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + der2, + prep::DI.NoSecondDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, oneunit(x)) @@ -667,12 +667,12 @@ function DI.second_derivative!( end function DI.value_derivative_and_second_derivative( - f::F, - prep::DI.NoSecondDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::DI.NoSecondDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, oneunit(x)) @@ -687,14 +687,14 @@ function DI.value_derivative_and_second_derivative( end function DI.value_derivative_and_second_derivative!( - f::F, - der, - der2, - prep::DI.NoSecondDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + der, + der2, + prep::DI.NoSecondDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, oneunit(x)) @@ -713,13 +713,13 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.hessian!( - f::F, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, hess, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian!(hess, fc, x) else @@ -729,13 +729,13 @@ function DI.hessian!( end function DI.hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian(fc, x) else @@ -745,18 +745,18 @@ function DI.hessian( end function DI.value_gradient_and_hessian!( - f::F, - grad, - hess, - backend::AutoForwardDiff{chunksize,T}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C,chunksize,T} + f::F, + grad, + hess, + backend::AutoForwardDiff{chunksize, T}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = DiffResult(oneunit(eltype(x)), (grad, hess)) result = hessian!(result, fc, x) @@ -771,13 +771,13 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f::F, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = HessianResult(x) result = hessian!(result, fc, x) @@ -790,7 +790,7 @@ end ### Prepared -struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} +struct ForwardDiffHessianPrep{SIG, C1, C2, CD} <: DI.HessianPrep{SIG} _sig::Val{SIG} array_config::C1 result_config::C2 @@ -798,8 +798,8 @@ struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} end function DI.prepare_hessian_nokwarg( - strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) @@ -811,13 +811,13 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian!( - f::F, - hess, - prep::ForwardDiffHessianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + hess, + prep::ForwardDiffHessianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -829,12 +829,12 @@ function DI.hessian!( end function DI.hessian( - f::F, - prep::ForwardDiffHessianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffHessianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -846,14 +846,14 @@ function DI.hessian( end function DI.value_gradient_and_hessian!( - f::F, - grad, - hess, - prep::ForwardDiffHessianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + hess, + prep::ForwardDiffHessianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) @@ -870,12 +870,12 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f::F, - prep::ForwardDiffHessianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::ForwardDiffHessianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 29d6cc1ab..6e7d53117 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -1,6 +1,6 @@ ## Pushforward -struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} +struct ForwardDiffTwoArgPushforwardPrep{SIG, T, X, Y, CD} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} _t::Type{T} xdual_tmp::X @@ -9,14 +9,14 @@ struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}; -) where {F,B,C} + strict::Val, + f!::F, + y, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C} + ) where {F, B, C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) T = tag_type(f!, backend, x) xdual_tmp = make_dual_similar(T, x, tx) @@ -26,13 +26,13 @@ function DI.prepare_pushforward_nokwarg( end function compute_ydual_twoarg( - f!::F, - y, - prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, - x::Number, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f!::F, + y, + prep::ForwardDiffTwoArgPushforwardPrep{SIG, T}, + x::Number, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} (; ydual_tmp) = prep xdual_tmp = make_dual(T, x, tx) contexts_dual = translate_prepared(contexts, prep.contexts_dual) @@ -41,13 +41,13 @@ function compute_ydual_twoarg( end function compute_ydual_twoarg( - f!::F, - y, - prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f!::F, + y, + prep::ForwardDiffTwoArgPushforwardPrep{SIG, T}, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} (; ydual_tmp) = prep if DI.ismutable_array(x) make_dual!(T, prep.xdual_tmp, x, tx) @@ -61,14 +61,14 @@ function compute_ydual_twoarg( end function DI.value_and_pushforward( - f!::F, - y, - prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f!::F, + y, + prep::ForwardDiffTwoArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) myvalue!(T, y, ydual_tmp) @@ -77,15 +77,15 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f!::F, - y, - ty::NTuple, - prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,C} + f!::F, + y, + ty::NTuple, + prep::ForwardDiffTwoArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) myvalue!(T, y, ydual_tmp) @@ -94,14 +94,14 @@ function DI.value_and_pushforward!( end function DI.pushforward( - f!::F, - y, - prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,B,C} + f!::F, + y, + prep::ForwardDiffTwoArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, B, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual_tmp) @@ -109,15 +109,15 @@ function DI.pushforward( end function DI.pushforward!( - f!::F, - y, - ty::NTuple, - prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,T,C} + f!::F, + y, + ty::NTuple, + prep::ForwardDiffTwoArgPushforwardPrep{SIG, T}, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, T, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) mypartials!(T, ty, ydual_tmp) @@ -129,9 +129,9 @@ end ### Unprepared, only when tag is not specified function DI.value_and_derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) + f!::F, y, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} + if (T === Nothing && contexts isa NTuple{C, DI.GeneralizedConstant}) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (similar(y),)) result = derivative!(result, fc!, y, x) @@ -143,9 +143,9 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) + f!::F, y, der, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} + if (T === Nothing && contexts isa NTuple{C, DI.GeneralizedConstant}) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (der,)) result = derivative!(result, fc!, y, x) @@ -157,9 +157,9 @@ function DI.value_and_derivative!( end function DI.derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) + f!::F, y, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} + if (T === Nothing && contexts isa NTuple{C, DI.GeneralizedConstant}) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return derivative(fc!, y, x) else @@ -169,9 +169,9 @@ function DI.derivative( end function DI.derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) + f!::F, y, der, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} + if (T === Nothing && contexts isa NTuple{C, DI.GeneralizedConstant}) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return derivative!(der, fc!, y, x) else @@ -182,15 +182,15 @@ end ### Prepared -struct ForwardDiffTwoArgDerivativePrep{SIG,C,CD} <: DI.DerivativePrep{SIG} +struct ForwardDiffTwoArgDerivativePrep{SIG, C, CD} <: DI.DerivativePrep{SIG} _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_derivative_nokwarg( - strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) tag = get_tag(f!, backend, x) config = DerivativeConfig(nothing, y, x, tag) @@ -199,13 +199,13 @@ function DI.prepare_derivative_nokwarg( end function DI.prepare!_derivative( - f!::F, - y, - old_prep::ForwardDiffTwoArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {F,C} + f!::F, + y, + old_prep::ForwardDiffTwoArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {F, C} DI.check_prep(f!, y, old_prep, backend, x, contexts...) if y isa Vector (; config) = old_prep @@ -219,13 +219,13 @@ function DI.prepare!_derivative( end function DI.value_and_derivative( - f!::F, - y, - prep::ForwardDiffTwoArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::ForwardDiffTwoArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) @@ -239,14 +239,14 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!::F, - y, - der, - prep::ForwardDiffTwoArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + der, + prep::ForwardDiffTwoArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) @@ -260,13 +260,13 @@ function DI.value_and_derivative!( end function DI.derivative( - f!::F, - y, - prep::ForwardDiffTwoArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::ForwardDiffTwoArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) @@ -278,14 +278,14 @@ function DI.derivative( end function DI.derivative!( - f!::F, - y, - der, - prep::ForwardDiffTwoArgDerivativePrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + der, + prep::ForwardDiffTwoArgDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) @@ -301,13 +301,13 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f!::F, y, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) @@ -320,13 +320,13 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f!::F, y, jac, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (jac,)) result = jacobian!(result, fc!, y, x) @@ -338,13 +338,13 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f!::F, y, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return jacobian(fc!, y, x) else @@ -354,13 +354,13 @@ function DI.jacobian( end function DI.jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} -) where {F,C,chunksize,T} + f!::F, y, jac, backend::AutoForwardDiff{chunksize, T}, x, contexts::Vararg{DI.Context, C} + ) where {F, C, chunksize, T} if ( - isnothing(chunksize) && - T === Nothing && - contexts isa NTuple{C,DI.GeneralizedConstant} - ) + isnothing(chunksize) && + T === Nothing && + contexts isa NTuple{C, DI.GeneralizedConstant} + ) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return jacobian!(jac, fc!, y, x) else @@ -371,15 +371,15 @@ end ### Prepared -struct ForwardDiffTwoArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} +struct ForwardDiffTwoArgJacobianPrep{SIG, C, CD} <: DI.JacobianPrep{SIG} _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_jacobian_nokwarg( - strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {F,C} + strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f!, backend, x) @@ -389,13 +389,13 @@ function DI.prepare_jacobian_nokwarg( end function DI.prepare!_jacobian( - f!::F, - y, - old_prep::ForwardDiffTwoArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {F,C} + f!::F, + y, + old_prep::ForwardDiffTwoArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {F, C} DI.check_prep(f!, y, old_prep, backend, x, contexts...) if x isa Vector && y isa Vector (; config) = old_prep @@ -411,13 +411,13 @@ function DI.prepare!_jacobian( end function DI.value_and_jacobian( - f!::F, - y, - prep::ForwardDiffTwoArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::ForwardDiffTwoArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) @@ -432,14 +432,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!::F, - y, - jac, - prep::ForwardDiffTwoArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + jac, + prep::ForwardDiffTwoArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) @@ -453,13 +453,13 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!::F, - y, - prep::ForwardDiffTwoArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::ForwardDiffTwoArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) @@ -471,14 +471,14 @@ function DI.jacobian( end function DI.jacobian!( - f!::F, - y, - jac, - prep::ForwardDiffTwoArgJacobianPrep, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + jac, + prep::ForwardDiffTwoArgJacobianPrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.fix_tail(f!, contexts_dual...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 14b17f2dd..e0d5b2dfe 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -8,10 +8,10 @@ function DI.pick_batchsize(::AutoForwardDiff{chunksize}, N::Integer) where {chun end function DI.threshold_batchsize( - backend::AutoForwardDiff{chunksize1}, chunksize2::Integer -) where {chunksize1} + backend::AutoForwardDiff{chunksize1}, chunksize2::Integer + ) where {chunksize1} chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2) - return AutoForwardDiff(; chunksize, tag=backend.tag) + return AutoForwardDiff(; chunksize, tag = backend.tag) end choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x) @@ -19,39 +19,39 @@ choose_chunk(::AutoForwardDiff{chunksize}, x) where {chunksize} = Chunk{chunksiz get_tag(f, backend::AutoForwardDiff, x) = backend.tag -function get_tag(f::F, backend::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize} +function get_tag(f::F, backend::AutoForwardDiff{chunksize, Nothing}, x) where {F, chunksize} return Tag(f, eltype(x)) end -tag_type(::AutoForwardDiff{chunksize,T}) where {chunksize,T} = T +tag_type(::AutoForwardDiff{chunksize, T}) where {chunksize, T} = T tag_type(f::F, backend::AutoForwardDiff, x) where {F} = typeof(get_tag(f, backend, x)) dual_type(config::DerivativeConfig) = eltype(config.duals) dual_type(config::GradientConfig) = eltype(config.duals) -dual_type(config::JacobianConfig{T,V,N}) where {T,V,N} = Dual{T,V,N} +dual_type(config::JacobianConfig{T, V, N}) where {T, V, N} = Dual{T, V, N} dual_type(config::HessianConfig) = dual_type(config.gradient_config) -function make_dual_similar(::Type{T}, x::Number, tx::NTuple{B}) where {T,B} +function make_dual_similar(::Type{T}, x::Number, tx::NTuple{B}) where {T, B} return Dual{T}(x, tx...) end -function make_dual_similar(::Type{T}, x, tx::NTuple{B}) where {T,B} - return similar(x, Dual{T,eltype(x),B}) +function make_dual_similar(::Type{T}, x, tx::NTuple{B}) where {T, B} + return similar(x, Dual{T, eltype(x), B}) end function make_dual(::Type{T}, x::Number, dx::Number) where {T} return Dual{T}(x, dx) end -function make_dual(::Type{T}, x::Number, tx::NTuple{B}) where {T,B} +function make_dual(::Type{T}, x::Number, tx::NTuple{B}) where {T, B} return Dual{T}(x, tx...) end -function make_dual(::Type{T}, x, tx::NTuple{B}) where {T,B} +function make_dual(::Type{T}, x, tx::NTuple{B}) where {T, B} return Dual{T}.(x, tx...) end -function make_dual!(::Type{T}, xdual, x, tx::NTuple{B}) where {T,B} +function make_dual!(::Type{T}, xdual, x, tx::NTuple{B}) where {T, B} return xdual .= Dual{T}.(x, tx...) end @@ -63,19 +63,19 @@ myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual) myderivative(::Type{T}, ydual) where {T} = myderivative.(T, ydual) myderivative!(::Type{T}, dy, ydual) where {T} = dy .= myderivative.(T, ydual) -function mypartials(::Type{T}, ::Val{B}, ydual::Number) where {T,B} +function mypartials(::Type{T}, ::Val{B}, ydual::Number) where {T, B} return ntuple(Val(B)) do b partials(T, ydual, b) end end -function mypartials(::Type{T}, ::Val{B}, ydual) where {T,B} +function mypartials(::Type{T}, ::Val{B}, ydual) where {T, B} return ntuple(Val(B)) do b partials.(T, ydual, b) end end -function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} +function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T, B} for b in eachindex(ty) ty[b] .= partials.(T, ydual, b) end @@ -83,16 +83,16 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} end function _translate( - ::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache} -) where {D<:Dual} + ::Type{D}, c::Union{DI.GeneralizedConstant, DI.ConstantOrCache} + ) where {D <: Dual} return DI.unwrap(c) end -function _translate(::Type{D}, c::DI.Cache) where {D<:Dual} +function _translate(::Type{D}, c::DI.Cache) where {D <: Dual} c0 = DI.unwrap(c) return DI.recursive_similar(c0, D) end -function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} +function translate(::Type{D}, contexts::NTuple{C, DI.Context}) where {D <: Dual, C} new_contexts = map(contexts) do c _translate(D, c) end @@ -100,30 +100,30 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} end function _translate_toprep( - ::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache} -) where {D<:Dual} + ::Type{D}, c::Union{DI.GeneralizedConstant, DI.ConstantOrCache} + ) where {D <: Dual} return nothing end -function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual} +function _translate_toprep(::Type{D}, c::DI.Cache) where {D <: Dual} c0 = DI.unwrap(c) return DI.recursive_similar(c0, D) end -function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} +function translate_toprep(::Type{D}, contexts::NTuple{C, DI.Context}) where {D <: Dual, C} new_contexts = map(contexts) do c _translate_toprep(D, c) end return new_contexts end -function _translate_prepared(c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}, _pc) +function _translate_prepared(c::Union{DI.GeneralizedConstant, DI.ConstantOrCache}, _pc) return DI.unwrap(c) end _translate_prepared(_c::DI.Cache, pc) = pc function translate_prepared( - contexts::NTuple{C,DI.Context}, prep_contexts::NTuple{C,Any} -) where {C} + contexts::NTuple{C, DI.Context}, prep_contexts::NTuple{C, Any} + ) where {C} new_contexts = map(contexts, prep_contexts) do c, pc _translate_prepared(c, pc) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 45257d21c..d39f4747e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -2,19 +2,19 @@ # Contains either a single pre-allocated initial TPS # or a vector of pre-allocated TPSs. -struct GTPSAOneArgPushforwardPrep{SIG,X} <: DI.PushforwardPrep{SIG} +struct GTPSAOneArgPushforwardPrep{SIG, X} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} xt::X end function DI.prepare_pushforward_nokwarg( - strict::Val, - f::F, - backend::AutoGTPSA{D}, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}; -) where {F,D,C} + strict::Val, + f::F, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C} + ) where {F, D, C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. @@ -24,29 +24,29 @@ function DI.prepare_pushforward_nokwarg( d = Descriptor(1, 1) # 1 variable to first order end if x isa Number - xt = TPS{promote_type(typeof(first(tx)), typeof(x), Float64)}(; use=d) + xt = TPS{promote_type(typeof(first(tx)), typeof(x), Float64)}(; use = d) return GTPSAOneArgPushforwardPrep(_sig, xt) else xt = similar(x, TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}) for i in eachindex(xt) - xt[i] = TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}(; use=d) + xt[i] = TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}(; use = d) end return GTPSAOneArgPushforwardPrep(_sig, xt) end end function DI.pushforward( - f, - prep::GTPSAOneArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + prep::GTPSAOneArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) ty = map(tx) do dx - foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx) + foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) yt = fc(prep.xt) if yt isa Number return yt[1] @@ -59,19 +59,19 @@ function DI.pushforward( end function DI.pushforward!( - f, - ty::NTuple, - prep::GTPSAOneArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + ty::NTuple, + prep::GTPSAOneArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx) + foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) yt = fc(prep.xt) map!(t -> t[1], dy, yt) end @@ -79,13 +79,13 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f, - prep::GTPSAOneArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + prep::GTPSAOneArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) ty = DI.pushforward(f, prep, backend, x, tx, contexts...) y = f(x, map(DI.unwrap, contexts)...) # TODO: optimize @@ -93,14 +93,14 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f, - ty::NTuple, - prep::GTPSAOneArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + ty::NTuple, + prep::GTPSAOneArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) y = f(x, map(DI.unwrap, contexts)...) # TODO: optimize @@ -109,15 +109,15 @@ end ## Gradient # Contains a vector of pre-allocated TPSs. -struct GTPSAOneArgGradientPrep{SIG,X} <: DI.GradientPrep{SIG} +struct GTPSAOneArgGradientPrep{SIG, X} <: DI.GradientPrep{SIG} _sig::Val{SIG} xt::X end # Unlike JVP, this requires us to use all variables function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} -) where {D,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant, C} + ) where {D, C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor @@ -127,7 +127,7 @@ function DI.prepare_gradient_nokwarg( xt = similar(x, TPS{promote_type(eltype(x), Float64)}) j = 1 for i in eachindex(xt) - xt[i] = TPS{promote_type(eltype(x), Float64)}(; use=d) + xt[i] = TPS{promote_type(eltype(x), Float64)}(; use = d) xt[i][j] = 1 j += 1 end @@ -135,72 +135,72 @@ function DI.prepare_gradient_nokwarg( end function DI.gradient( - f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} -) where {C} + f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) grad = similar(x, GTPSA.numtype(yt)) - GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) + GTPSA.gradient!(grad, yt; include_params = true, unsafe_inbounds = true) return grad end function DI.gradient!( - f, - grad, - prep::GTPSAOneArgGradientPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + grad, + prep::GTPSAOneArgGradientPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) - GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) + GTPSA.gradient!(grad, yt; include_params = true, unsafe_inbounds = true) return grad end function DI.value_and_gradient( - f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} -) where {C} + f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) grad = similar(x, GTPSA.numtype(yt)) - GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) + GTPSA.gradient!(grad, yt; include_params = true, unsafe_inbounds = true) return yt[0], grad end function DI.value_and_gradient!( - f, - grad, - prep::GTPSAOneArgGradientPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + grad, + prep::GTPSAOneArgGradientPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) - GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) + GTPSA.gradient!(grad, yt; include_params = true, unsafe_inbounds = true) return yt[0], grad end ## Jacobian # Contains a vector of pre-allocated TPSs -struct GTPSAOneArgJacobianPrep{SIG,X} <: DI.JacobianPrep{SIG} +struct GTPSAOneArgJacobianPrep{SIG, X} <: DI.JacobianPrep{SIG} _sig::Val{SIG} xt::X end # To materialize the entire Jacobian we use all variables function DI.prepare_jacobian_nokwarg( - strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} -) where {D,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant, C} + ) where {D, C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor @@ -212,7 +212,7 @@ function DI.prepare_jacobian_nokwarg( xt = similar(x, TPS{promote_type(eltype(x), Float64)}) j = 1 for i in eachindex(xt) - xt[i] = TPS{promote_type(eltype(x), Float64)}(; use=d) + xt[i] = TPS{promote_type(eltype(x), Float64)}(; use = d) xt[i][j] = 1 j += 1 end @@ -220,91 +220,91 @@ function DI.prepare_jacobian_nokwarg( end function DI.jacobian( - f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} -) where {C} + f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) jac = similar(x, GTPSA.numtype(eltype(yt)), (length(yt), length(x))) - GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) + GTPSA.jacobian!(jac, yt; include_params = true, unsafe_inbounds = true) return jac end function DI.jacobian!( - f, - jac, - prep::GTPSAOneArgJacobianPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + jac, + prep::GTPSAOneArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) - GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) + GTPSA.jacobian!(jac, yt; include_params = true, unsafe_inbounds = true) return jac end function DI.value_and_jacobian( - f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} -) where {C} + f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) jac = similar(x, GTPSA.numtype(eltype(yt)), (length(yt), length(x))) - GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) + GTPSA.jacobian!(jac, yt; include_params = true, unsafe_inbounds = true) y = map(t -> t[0], yt) return y, jac end function DI.value_and_jacobian!( - f, - jac, - prep::GTPSAOneArgJacobianPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + jac, + prep::GTPSAOneArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) - GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) + GTPSA.jacobian!(jac, yt; include_params = true, unsafe_inbounds = true) y = map(t -> t[0], yt) return y, jac end ## Second derivative # Contains single pre-allocated TPS -struct GTPSAOneArgSecondDerivativePrep{SIG,X} <: DI.SecondDerivativePrep{SIG} +struct GTPSAOneArgSecondDerivativePrep{SIG, X} <: DI.SecondDerivativePrep{SIG} _sig::Val{SIG} xt::X end function DI.prepare_second_derivative_nokwarg( - strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} -) where {D,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant, C} + ) where {D, C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else d = Descriptor(1, 2) end - xt = TPS{promote_type(typeof(x), Float64)}(; use=d) + xt = TPS{promote_type(typeof(x), Float64)}(; use = d) xt[1] = 1 # Set slope return GTPSAOneArgSecondDerivativePrep(_sig, xt) end function DI.second_derivative( - f, - prep::GTPSAOneArgSecondDerivativePrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + prep::GTPSAOneArgSecondDerivativePrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -327,13 +327,13 @@ function DI.second_derivative( end function DI.second_derivative!( - f, - der2, - prep::GTPSAOneArgSecondDerivativePrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + der2, + prep::GTPSAOneArgSecondDerivativePrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -350,12 +350,12 @@ function DI.second_derivative!( end function DI.value_derivative_and_second_derivative( - f, - prep::GTPSAOneArgSecondDerivativePrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + prep::GTPSAOneArgSecondDerivativePrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -380,14 +380,14 @@ function DI.value_derivative_and_second_derivative( end function DI.value_derivative_and_second_derivative!( - f, - der, - der2, - prep::GTPSAOneArgSecondDerivativePrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + der, + der2, + prep::GTPSAOneArgSecondDerivativePrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -408,15 +408,15 @@ end ## Hessian # Stores allocated array of TPS and an array for the monomial coefficient # indexing in GTPSA.cycle! (which is used if a Descriptor is specified) -struct GTPSAOneArgHessianPrep{SIG,X,M} <: DI.HessianPrep{SIG} +struct GTPSAOneArgHessianPrep{SIG, X, M} <: DI.HessianPrep{SIG} _sig::Val{SIG} xt::X m::M end function DI.prepare_hessian_nokwarg( - strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} -) where {D,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant, C} + ) where {D, C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor @@ -435,7 +435,7 @@ function DI.prepare_hessian_nokwarg( # linear with the variables. j = 1 for i in eachindex(xt) - xt[i] = TPS{promote_type(eltype(x), Float64)}(; use=d) + xt[i] = TPS{promote_type(eltype(x), Float64)}(; use = d) xt[i][j] = 1 j += 1 end @@ -444,12 +444,12 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian( - f, - prep::GTPSAOneArgHessianPrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -459,22 +459,22 @@ function DI.hessian( GTPSA.hessian!( hess, yt; - include_params=true, - unsafe_inbounds=true, - unsafe_fast=unsafe_fast, - tmp_mono=prep.m, + include_params = true, + unsafe_inbounds = true, + unsafe_fast = unsafe_fast, + tmp_mono = prep.m, ) return hess end function DI.hessian!( - f, - hess, - prep::GTPSAOneArgHessianPrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + hess, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -483,75 +483,75 @@ function DI.hessian!( GTPSA.hessian!( hess, yt; - include_params=true, - unsafe_inbounds=true, - unsafe_fast=unsafe_fast, - tmp_mono=prep.m, + include_params = true, + unsafe_inbounds = true, + unsafe_fast = unsafe_fast, + tmp_mono = prep.m, ) return hess end function DI.value_gradient_and_hessian( - f, - prep::GTPSAOneArgHessianPrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) grad = similar(x, GTPSA.numtype(yt)) - GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) + GTPSA.gradient!(grad, yt; include_params = true, unsafe_inbounds = true) hess = similar(x, GTPSA.numtype(yt), (length(x), length(x))) unsafe_fast = D == Nothing ? true : false GTPSA.hessian!( hess, yt; - include_params=true, - unsafe_inbounds=true, - unsafe_fast=unsafe_fast, - tmp_mono=prep.m, + include_params = true, + unsafe_inbounds = true, + unsafe_fast = unsafe_fast, + tmp_mono = prep.m, ) return yt[0], grad, hess end function DI.value_gradient_and_hessian!( - f, - grad, - hess, - prep::GTPSAOneArgHessianPrep, - backend::AutoGTPSA{D}, - x, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + grad, + hess, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) - GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) + GTPSA.gradient!(grad, yt; include_params = true, unsafe_inbounds = true) unsafe_fast = D == Nothing ? true : false GTPSA.hessian!( hess, yt; - include_params=true, - unsafe_inbounds=true, - unsafe_fast=unsafe_fast, - tmp_mono=prep.m, + include_params = true, + unsafe_inbounds = true, + unsafe_fast = unsafe_fast, + tmp_mono = prep.m, ) return yt[0], grad, hess end -struct GTPSAOneArgHVPPrep{SIG,E,H} <: DI.HVPPrep{SIG} +struct GTPSAOneArgHVPPrep{SIG, E, H} <: DI.HVPPrep{SIG} _sig::Val{SIG} hessprep::E hess::H end function DI.prepare_hvp_nokwarg( - strict::Val, f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C} -) where {C} + strict::Val, f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) hessprep = DI.prepare_hessian_nokwarg(strict, f, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) @@ -560,13 +560,13 @@ function DI.prepare_hvp_nokwarg( end function DI.hvp( - f, - prep::GTPSAOneArgHVPPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + prep::GTPSAOneArgHVPPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hessian!(f, prep.hess, prep.hessprep, backend, x, contexts...) tg = map(tx) do dx @@ -585,14 +585,14 @@ function DI.hvp( end function DI.hvp!( - f, - tg::NTuple, - prep::GTPSAOneArgHVPPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f, + tg::NTuple, + prep::GTPSAOneArgHVPPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hessian!(f, prep.hess, prep.hessprep, backend, x, contexts...) for b in eachindex(tg) @@ -610,13 +610,13 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, - prep::GTPSAOneArgHVPPrep, - backend::AutoGTPSA{D}, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + prep::GTPSAOneArgHVPPrep, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, tx, contexts...) grad = similar(x, eltype(prep.hess)) DI.value_gradient_and_hessian!( @@ -638,15 +638,15 @@ function DI.gradient_and_hvp( end function DI.gradient_and_hvp!( - f, - grad, - tg, - prep::GTPSAOneArgHVPPrep, - backend::AutoGTPSA{D}, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {D,C} + f, + grad, + tg, + prep::GTPSAOneArgHVPPrep, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {D, C} DI.check_prep(f, prep, backend, x, tx, contexts...) DI.value_gradient_and_hessian!( f, grad, prep.hess, prep.hessprep, backend, x, contexts... diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index 5edbdf9e5..425b16c7b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -4,21 +4,21 @@ # or a vector of pre-allocated TPSs. # # Output: Contains a vector of pre-allocated TPSs -struct GTPSATwoArgPushforwardPrep{SIG,X,Y} <: DI.PushforwardPrep{SIG} +struct GTPSATwoArgPushforwardPrep{SIG, X, Y} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} xt::X yt::Y end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoGTPSA{D}, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}; -) where {F,D,C} + strict::Val, + f!::F, + y, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C} + ) where {F, D, C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. @@ -28,35 +28,35 @@ function DI.prepare_pushforward_nokwarg( d = Descriptor(1, 1) # 1 variable to first order end if x isa Number - xt = TPS{promote_type(typeof(first(tx)), typeof(x), Float64)}(; use=d) + xt = TPS{promote_type(typeof(first(tx)), typeof(x), Float64)}(; use = d) else xt = similar(x, TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}) for i in eachindex(xt) - xt[i] = TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}(; use=d) + xt[i] = TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}(; use = d) end end yt = similar(y, TPS{promote_type(eltype(y), Float64)}) for i in eachindex(yt) - yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) + yt[i] = TPS{promote_type(eltype(y), Float64)}(; use = d) end return GTPSATwoArgPushforwardPrep(_sig, xt, yt) end function DI.pushforward( - f!, - y, - prep::GTPSATwoArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + prep::GTPSATwoArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) ty = map(tx) do dx - foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx) + foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) fc!(prep.yt, prep.xt) dy = map(t -> t[1], prep.yt) return dy @@ -66,20 +66,20 @@ function DI.pushforward( end function DI.pushforward!( - f!, - y, - ty::NTuple, - prep::GTPSATwoArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::GTPSATwoArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx) + foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) fc!(prep.yt, prep.xt) map!(t -> t[1], dy, prep.yt) end @@ -88,29 +88,29 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f!, - y, - prep::GTPSATwoArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + prep::GTPSATwoArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) return y, ty end function DI.value_and_pushforward!( - f!, - y, - ty::NTuple, - prep::GTPSATwoArgPushforwardPrep, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::GTPSATwoArgPushforwardPrep, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) return y, ty @@ -119,15 +119,15 @@ end ## Jacobian # Input: Contains a vector of pre-allocated TPSs # Output: Contains a vector of pre-allocated TPSs -struct GTPSATwoArgJacobianPrep{SIG,X,Y} <: DI.JacobianPrep{SIG} +struct GTPSATwoArgJacobianPrep{SIG, X, Y} <: DI.JacobianPrep{SIG} _sig::Val{SIG} xt::X yt::Y end function DI.prepare_jacobian_nokwarg( - strict::Val, f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} -) where {D,C} + strict::Val, f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant, C} + ) where {D, C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor @@ -139,7 +139,7 @@ function DI.prepare_jacobian_nokwarg( xt = similar(x, TPS{promote_type(eltype(x), Float64)}) j = 1 for i in eachindex(xt) - xt[i] = TPS{promote_type(eltype(x), Float64)}(; use=d) + xt[i] = TPS{promote_type(eltype(x), Float64)}(; use = d) xt[i][j] = 1 j += 1 end @@ -147,70 +147,70 @@ function DI.prepare_jacobian_nokwarg( yt = similar(y, TPS{promote_type(eltype(y), Float64)}) for i in eachindex(yt) - yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) + yt[i] = TPS{promote_type(eltype(y), Float64)}(; use = d) end return GTPSATwoArgJacobianPrep(_sig, xt, yt) end function DI.jacobian( - f!, - y, - prep::GTPSATwoArgJacobianPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + prep::GTPSATwoArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) fc!(prep.yt, prep.xt) jac = similar(x, GTPSA.numtype(eltype(prep.yt)), (length(prep.yt), length(x))) - GTPSA.jacobian!(jac, prep.yt; include_params=true, unsafe_inbounds=true) + GTPSA.jacobian!(jac, prep.yt; include_params = true, unsafe_inbounds = true) map!(t -> t[0], y, prep.yt) return jac end function DI.jacobian!( - f!, - y, - jac, - prep::GTPSATwoArgJacobianPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + jac, + prep::GTPSATwoArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) fc!(prep.yt, prep.xt) - GTPSA.jacobian!(jac, prep.yt; include_params=true, unsafe_inbounds=true) + GTPSA.jacobian!(jac, prep.yt; include_params = true, unsafe_inbounds = true) map!(t -> t[0], y, prep.yt) return jac end function DI.value_and_jacobian( - f!, - y, - prep::GTPSATwoArgJacobianPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + prep::GTPSATwoArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) jac = DI.jacobian(f!, y, prep, backend, x, contexts...) # y set on line 151 return y, jac end function DI.value_and_jacobian!( - f!, - y, - jac, - prep::GTPSATwoArgJacobianPrep, - backend::AutoGTPSA, - x, - contexts::Vararg{DI.Constant,C}, -) where {C} + f!, + y, + jac, + prep::GTPSATwoArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) return y, jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index d037498c9..740384101 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -31,7 +31,7 @@ using Mooncake: _copy_output, _copy_to_output!! -const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}} +const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} DI.check_available(::AnyAutoMooncake{C}) where {C} = true diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 3b4fb91c3..ad2d9f7c7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,10 +1,10 @@ -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Any} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any} struct MooncakeDifferentiateWithError <: Exception F::Type X::Type Y::Type - function MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F,X,Y} + function MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F, X, Y} return new(F, X, Y) end end @@ -48,8 +48,8 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number end function Mooncake.rrule!!( - dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}} -) + dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}} + ) primal_func = primal(dw) primal_x = primal(x) fdata_arg = x.dx diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index e8bee9ca8..abf1565a6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -1,19 +1,19 @@ ## Pushforward -struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG} +struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::Tcache dx_righttype::DX end function DI.prepare_pushforward_nokwarg( - strict::Val, - f::F, - backend::AutoMooncakeForward, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) config = get_config(backend) cache = prepare_derivative_cache( @@ -25,13 +25,13 @@ function DI.prepare_pushforward_nokwarg( end function DI.value_and_pushforward( - f::F, - prep::MooncakeOneArgPushforwardPrep, - backend::AutoMooncakeForward, - x::X, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C,X} + f::F, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C, X} DI.check_prep(f, prep, backend, x, tx, contexts...) ys_and_ty = map(tx) do dx dx_righttype = @@ -52,26 +52,26 @@ function DI.value_and_pushforward( end function DI.pushforward( - f::F, - prep::MooncakeOneArgPushforwardPrep, - backend::AutoMooncakeForward, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + f::F, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] end function DI.value_and_pushforward!( - f::F, - ty::NTuple, - prep::MooncakeOneArgPushforwardPrep, - backend::AutoMooncakeForward, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) @@ -79,14 +79,14 @@ function DI.value_and_pushforward!( end function DI.pushforward!( - f::F, - ty::NTuple, - prep::MooncakeOneArgPushforwardPrep, - backend::AutoMooncakeForward, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...) return ty diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 56b655b2e..26539d305 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -1,6 +1,6 @@ ## Pushforward -struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG} +struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::Tcache dx_righttype::DX @@ -8,14 +8,14 @@ struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoMooncakeForward, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) config = get_config(backend) cache = prepare_derivative_cache( @@ -33,14 +33,14 @@ function DI.prepare_pushforward_nokwarg( end function DI.value_and_pushforward( - f!::F, - y, - prep::MooncakeTwoArgPushforwardPrep, - backend::AutoMooncakeForward, - x::X, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C,X} + f!::F, + y, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C, X} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx dx_righttype = @@ -60,28 +60,28 @@ function DI.value_and_pushforward( end function DI.pushforward( - f!::F, - y, - prep::MooncakeTwoArgPushforwardPrep, - backend::AutoMooncakeForward, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + f!::F, + y, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) return DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2] end function DI.value_and_pushforward!( - f!::F, - y::Y, - ty::NTuple, - prep::MooncakeTwoArgPushforwardPrep, - backend::AutoMooncakeForward, - x::X, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C,X,Y} + f!::F, + y::Y, + ty::NTuple, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C, X, Y} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) foreach(tx, ty) do dx, dy dx_righttype = @@ -101,15 +101,15 @@ function DI.value_and_pushforward!( end function DI.pushforward!( - f!::F, - y, - ty::NTuple, - prep::MooncakeTwoArgPushforwardPrep, - backend::AutoMooncakeForward, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + f!::F, + y, + ty::NTuple, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) return ty diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 131035e6b..209367d5a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,14 +1,14 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{SIG,Tcache,DY} <: DI.PullbackPrep{SIG} +struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache dy_righttype::DY end function DI.prepare_pullback_nokwarg( - strict::Val, f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) cache = prepare_pullback_cache( @@ -21,13 +21,13 @@ function DI.prepare_pullback_nokwarg( end function DI.value_and_pullback( - f::F, - prep::MooncakeOneArgPullbackPrep{Y}, - backend::AutoMooncake, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,Y,C} + f::F, + prep::MooncakeOneArgPullbackPrep{Y}, + backend::AutoMooncake, + x, + ty::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) @@ -38,13 +38,13 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f::F, - prep::MooncakeOneArgPullbackPrep{Y}, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,Y,C} + f::F, + prep::MooncakeOneArgPullbackPrep{Y}, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy dy_righttype = @@ -60,14 +60,14 @@ function DI.value_and_pullback( end function DI.value_and_pullback!( - f::F, - tx::NTuple, - prep::MooncakeOneArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + tx::NTuple, + prep::MooncakeOneArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) y, new_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) @@ -75,40 +75,40 @@ function DI.value_and_pullback!( end function DI.pullback( - f::F, - prep::MooncakeOneArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::MooncakeOneArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) return DI.value_and_pullback(f, prep, backend, x, ty, contexts...)[2] end function DI.pullback!( - f::F, - tx::NTuple, - prep::MooncakeOneArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + tx::NTuple, + prep::MooncakeOneArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) return DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2] end ## Gradient -struct MooncakeGradientPrep{SIG,Tcache} <: DI.GradientPrep{SIG} +struct MooncakeGradientPrep{SIG, Tcache} <: DI.GradientPrep{SIG} _sig::Val{SIG} cache::Tcache end function DI.prepare_gradient_nokwarg( - strict::Val, f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C} -) where {F,C} + strict::Val, f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) cache = prepare_gradient_cache( @@ -119,25 +119,25 @@ function DI.prepare_gradient_nokwarg( end function DI.value_and_gradient( - f::F, - prep::MooncakeGradientPrep, - backend::AutoMooncake, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::MooncakeGradientPrep, + backend::AutoMooncake, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) return y, _copy_output(new_grad) end function DI.value_and_gradient!( - f::F, - grad, - prep::MooncakeGradientPrep, - backend::AutoMooncake, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + prep::MooncakeGradientPrep, + backend::AutoMooncake, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) copyto!(grad, new_grad) @@ -145,25 +145,25 @@ function DI.value_and_gradient!( end function DI.gradient( - f::F, - prep::MooncakeGradientPrep, - backend::AutoMooncake, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::MooncakeGradientPrep, + backend::AutoMooncake, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) _, grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return grad end function DI.gradient!( - f::F, - grad, - prep::MooncakeGradientPrep, - backend::AutoMooncake, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + prep::MooncakeGradientPrep, + backend::AutoMooncake, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) return grad diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index d0bbf282d..da3e5b217 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,4 +1,4 @@ -struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} +struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache dy_righttype::DY @@ -6,14 +6,14 @@ struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) target_function = function (f!, y, x, contexts...) f!(y, x, contexts...) @@ -26,8 +26,8 @@ function DI.prepare_pullback_nokwarg( y, x, map(DI.unwrap, contexts)...; - debug_mode=config.debug_mode, - silence_debug_messages=config.silence_debug_messages, + debug_mode = config.debug_mode, + silence_debug_messages = config.silence_debug_messages, ) dy_righttype_after = zero_tangent(y) prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function) @@ -35,14 +35,14 @@ function DI.prepare_pullback_nokwarg( end function DI.value_and_pullback( - f!::F, - y, - prep::MooncakeTwoArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::MooncakeTwoArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple{1}, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) dy = only(ty) # Prepare cotangent to add after the forward pass. @@ -62,14 +62,14 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f!::F, - y, - prep::MooncakeTwoArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::MooncakeTwoArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy dy_righttype_after = copyto!(prep.dy_righttype, dy) @@ -89,15 +89,15 @@ function DI.value_and_pullback( end function DI.value_and_pullback!( - f!::F, - y, - tx::NTuple, - prep::MooncakeTwoArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + tx::NTuple, + prep::MooncakeTwoArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) _, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) @@ -105,28 +105,28 @@ function DI.value_and_pullback!( end function DI.pullback( - f!::F, - y, - prep::MooncakeTwoArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::MooncakeTwoArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) return DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2] end function DI.pullback!( - f!::F, - y, - tx::NTuple, - prep::MooncakeTwoArgPullbackPrep, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + tx::NTuple, + prep::MooncakeTwoArgPullbackPrep, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) return DI.value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2] end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index 666b7f5f9..ec3af69b9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -10,8 +10,8 @@ using DiffResults: DiffResults const FDExt = Base.get_extension(DI, :DifferentiationInterfaceForwardDiffExt) @assert !isnothing(FDExt) -function single_threaded(backend::AutoPolyesterForwardDiff{chunksize,T}) where {chunksize,T} - return AutoForwardDiff(; chunksize, tag=backend.tag) +function single_threaded(backend::AutoPolyesterForwardDiff{chunksize, T}) where {chunksize, T} + return AutoForwardDiff(; chunksize, tag = backend.tag) end DI.check_available(::AutoPolyesterForwardDiff) = true diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl index 2262e407c..d95a44293 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl @@ -1,11 +1,11 @@ function DI.overloaded_input( - ::typeof(DI.pushforward), f::F, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B} -) where {F,B} + ::typeof(DI.pushforward), f::F, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B} + ) where {F, B} return DI.overloaded_input(DI.pushforward, f, single_threaded(backend), x, tx) end function DI.overloaded_input( - ::typeof(DI.pushforward), f!::F, y, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B} -) where {F,B} + ::typeof(DI.pushforward), f!::F, y, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B} + ) where {F, B} return DI.overloaded_input(DI.pushforward, f!, y, single_threaded(backend), x, tx) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 44af4237c..65e2c5df9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -1,19 +1,18 @@ - ## Pushforward -struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG} +struct PolyesterForwardDiffOneArgPushforwardPrep{SIG, P} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} single_threaded_prep::P end function DI.prepare_pushforward_nokwarg( - strict::Val, - f, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward_nokwarg( strict, f, single_threaded(backend), x, tx, contexts... @@ -22,13 +21,13 @@ function DI.prepare_pushforward_nokwarg( end function DI.value_and_pushforward( - f, - prep::PolyesterForwardDiffOneArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.value_and_pushforward( f, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -36,14 +35,14 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f, - ty::NTuple, - prep::PolyesterForwardDiffOneArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + ty::NTuple, + prep::PolyesterForwardDiffOneArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.value_and_pushforward!( f, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -51,13 +50,13 @@ function DI.value_and_pushforward!( end function DI.pushforward( - f, - prep::PolyesterForwardDiffOneArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.pushforward( f, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -65,14 +64,14 @@ function DI.pushforward( end function DI.pushforward!( - f, - ty::NTuple, - prep::PolyesterForwardDiffOneArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + ty::NTuple, + prep::PolyesterForwardDiffOneArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.pushforward!( f, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -81,14 +80,14 @@ end ## Derivative -struct PolyesterForwardDiffOneArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} +struct PolyesterForwardDiffOneArgDerivativePrep{SIG, P} <: DI.DerivativePrep{SIG} _sig::Val{SIG} single_threaded_prep::P end function DI.prepare_derivative_nokwarg( - strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_derivative_nokwarg( strict, f, single_threaded(backend), x, contexts... @@ -97,12 +96,12 @@ function DI.prepare_derivative_nokwarg( end function DI.value_and_derivative( - f, - prep::PolyesterForwardDiffOneArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_and_derivative( f, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -110,13 +109,13 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f, - der, - prep::PolyesterForwardDiffOneArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + prep::PolyesterForwardDiffOneArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_and_derivative!( f, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -124,12 +123,12 @@ function DI.value_and_derivative!( end function DI.derivative( - f, - prep::PolyesterForwardDiffOneArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.derivative( f, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -137,13 +136,13 @@ function DI.derivative( end function DI.derivative!( - f, - der, - prep::PolyesterForwardDiffOneArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + prep::PolyesterForwardDiffOneArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.derivative!( f, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -152,19 +151,19 @@ end ## Gradient -struct PolyesterForwardDiffGradientPrep{SIG,chunksize,P} <: DI.GradientPrep{SIG} +struct PolyesterForwardDiffGradientPrep{SIG, chunksize, P} <: DI.GradientPrep{SIG} _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_gradient_nokwarg( - strict::Val, - f, - backend::AutoPolyesterForwardDiff{chunksize}, - x, - contexts::Vararg{DI.Context,C}; -) where {chunksize,C} + strict::Val, + f, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context, C} + ) where {chunksize, C} _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) @@ -178,15 +177,15 @@ function DI.prepare_gradient_nokwarg( end function DI.value_and_gradient!( - f, - grad, - prep::PolyesterForwardDiffGradientPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::PolyesterForwardDiffGradientPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) threaded_gradient!(fc, grad, x, prep.chunk) return fc(x), grad @@ -199,15 +198,15 @@ function DI.value_and_gradient!( end function DI.gradient!( - f, - grad, - prep::PolyesterForwardDiffGradientPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::PolyesterForwardDiffGradientPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) threaded_gradient!(fc, grad, x, prep.chunk) return grad @@ -220,42 +219,42 @@ function DI.gradient!( end function DI.value_and_gradient( - f, - prep::PolyesterForwardDiffGradientPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffGradientPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_and_gradient!(f, similar(x), prep, backend, x, contexts...) end function DI.gradient( - f, - prep::PolyesterForwardDiffGradientPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffGradientPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.gradient!(f, similar(x), prep, backend, x, contexts...) end ## Jacobian -struct PolyesterForwardDiffOneArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} +struct PolyesterForwardDiffOneArgJacobianPrep{SIG, chunksize, P} <: DI.JacobianPrep{SIG} _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_jacobian_nokwarg( - strict::Val, - f, - backend::AutoPolyesterForwardDiff{chunksize}, - x, - contexts::Vararg{DI.Context,C}; -) where {chunksize,C} + strict::Val, + f, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context, C} + ) where {chunksize, C} _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) @@ -269,15 +268,15 @@ function DI.prepare_jacobian_nokwarg( end function DI.value_and_jacobian!( - f, - jac, - prep::PolyesterForwardDiffOneArgJacobianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::PolyesterForwardDiffOneArgJacobianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk) else @@ -288,15 +287,15 @@ function DI.value_and_jacobian!( end function DI.jacobian!( - f, - jac, - prep::PolyesterForwardDiffOneArgJacobianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::PolyesterForwardDiffOneArgJacobianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return threaded_jacobian!(fc, jac, x, prep.chunk) else @@ -307,12 +306,12 @@ function DI.jacobian!( end function DI.value_and_jacobian( - f, - prep::PolyesterForwardDiffOneArgJacobianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgJacobianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) @@ -320,12 +319,12 @@ function DI.value_and_jacobian( end function DI.jacobian( - f, - prep::PolyesterForwardDiffOneArgJacobianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgJacobianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) @@ -334,14 +333,14 @@ end ## Hessian -struct PolyesterForwardDiffHessianPrep{SIG,P} <: DI.HessianPrep{SIG} +struct PolyesterForwardDiffHessianPrep{SIG, P} <: DI.HessianPrep{SIG} _sig::Val{SIG} single_threaded_prep::P end function DI.prepare_hessian_nokwarg( - strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_hessian_nokwarg( strict, f, single_threaded(backend), x, contexts... @@ -350,12 +349,12 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian( - f, - prep::PolyesterForwardDiffHessianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffHessianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.hessian( f, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -363,13 +362,13 @@ function DI.hessian( end function DI.hessian!( - f, - hess, - prep::PolyesterForwardDiffHessianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + hess, + prep::PolyesterForwardDiffHessianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.hessian!( f, hess, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -377,12 +376,12 @@ function DI.hessian!( end function DI.value_gradient_and_hessian( - f, - prep::PolyesterForwardDiffHessianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffHessianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_gradient_and_hessian( f, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -390,14 +389,14 @@ function DI.value_gradient_and_hessian( end function DI.value_gradient_and_hessian!( - f, - grad, - hess, - prep::PolyesterForwardDiffHessianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + hess, + prep::PolyesterForwardDiffHessianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_gradient_and_hessian!( f, grad, hess, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -406,14 +405,14 @@ end ## Second derivative -struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG,P} <: DI.SecondDerivativePrep{SIG} +struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG, P} <: DI.SecondDerivativePrep{SIG} _sig::Val{SIG} single_threaded_prep::P end function DI.prepare_second_derivative_nokwarg( - strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_second_derivative_nokwarg( strict, f, single_threaded(backend), x, contexts... @@ -422,12 +421,12 @@ function DI.prepare_second_derivative_nokwarg( end function DI.value_derivative_and_second_derivative( - f, - prep::PolyesterForwardDiffOneArgSecondDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_derivative_and_second_derivative( f, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -435,14 +434,14 @@ function DI.value_derivative_and_second_derivative( end function DI.value_derivative_and_second_derivative!( - f, - der, - der2, - prep::PolyesterForwardDiffOneArgSecondDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + der2, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_derivative_and_second_derivative!( f, der, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -450,12 +449,12 @@ function DI.value_derivative_and_second_derivative!( end function DI.second_derivative( - f, - prep::PolyesterForwardDiffOneArgSecondDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.second_derivative( f, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -463,13 +462,13 @@ function DI.second_derivative( end function DI.second_derivative!( - f, - der2, - prep::PolyesterForwardDiffOneArgSecondDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der2, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.second_derivative!( f, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts... diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 2b270e4cd..2fdfb405a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -1,19 +1,19 @@ ## Pushforward -struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG} +struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG, P} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} single_threaded_prep::P end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!, - y, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward_nokwarg( strict, f!, y, single_threaded(backend), x, tx, contexts... @@ -22,14 +22,14 @@ function DI.prepare_pushforward_nokwarg( end function DI.value_and_pushforward( - f!, - y, - prep::PolyesterForwardDiffTwoArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) return DI.value_and_pushforward( f!, y, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -37,15 +37,15 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f!, - y, - ty::NTuple, - prep::PolyesterForwardDiffTwoArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) return DI.value_and_pushforward!( f!, y, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -53,14 +53,14 @@ function DI.value_and_pushforward!( end function DI.pushforward( - f!, - y, - prep::PolyesterForwardDiffTwoArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) return DI.pushforward( f!, y, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -68,15 +68,15 @@ function DI.pushforward( end function DI.pushforward!( - f!, - y, - ty::NTuple, - prep::PolyesterForwardDiffTwoArgPushforwardPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) return DI.pushforward!( f!, y, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... @@ -85,14 +85,14 @@ end ## Derivative -struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} +struct PolyesterForwardDiffTwoArgDerivativePrep{SIG, P} <: DI.DerivativePrep{SIG} _sig::Val{SIG} single_threaded_prep::P end function DI.prepare_derivative_nokwarg( - strict::Val, f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_derivative_nokwarg( strict, f!, y, single_threaded(backend), x, contexts... @@ -101,13 +101,13 @@ function DI.prepare_derivative_nokwarg( end function DI.value_and_derivative( - f!, - y, - prep::PolyesterForwardDiffTwoArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::PolyesterForwardDiffTwoArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) return DI.value_and_derivative( f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -115,14 +115,14 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!, - y, - der, - prep::PolyesterForwardDiffTwoArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::PolyesterForwardDiffTwoArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) return DI.value_and_derivative!( f!, y, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -130,13 +130,13 @@ function DI.value_and_derivative!( end function DI.derivative( - f!, - y, - prep::PolyesterForwardDiffTwoArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::PolyesterForwardDiffTwoArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) return DI.derivative( f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -144,14 +144,14 @@ function DI.derivative( end function DI.derivative!( - f!, - y, - der, - prep::PolyesterForwardDiffTwoArgDerivativePrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::PolyesterForwardDiffTwoArgDerivativePrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) return DI.derivative!( f!, y, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... @@ -160,20 +160,20 @@ end ## Jacobian -struct PolyesterForwardDiffTwoArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} +struct PolyesterForwardDiffTwoArgJacobianPrep{SIG, chunksize, P} <: DI.JacobianPrep{SIG} _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_jacobian_nokwarg( - strict::Val, - f!, - y, - backend::AutoPolyesterForwardDiff{chunksize}, - x, - contexts::Vararg{DI.Context,C}; -) where {chunksize,C} + strict::Val, + f!, + y, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context, C} + ) where {chunksize, C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) @@ -187,15 +187,15 @@ function DI.prepare_jacobian_nokwarg( end function DI.value_and_jacobian( - f!, - y, - prep::PolyesterForwardDiffTwoArgJacobianPrep, - backend::AutoPolyesterForwardDiff{K}, - x, - contexts::Vararg{DI.Context,C}, -) where {K,C} + f!, + y, + prep::PolyesterForwardDiffTwoArgJacobianPrep, + backend::AutoPolyesterForwardDiff{K}, + x, + contexts::Vararg{DI.Context, C}, + ) where {K, C} DI.check_prep(f!, y, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) threaded_jacobian!(fc!, y, jac, x, prep.chunk) @@ -209,16 +209,16 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, - y, - jac, - prep::PolyesterForwardDiffTwoArgJacobianPrep, - backend::AutoPolyesterForwardDiff{K}, - x, - contexts::Vararg{DI.Context,C}, -) where {K,C} + f!, + y, + jac, + prep::PolyesterForwardDiffTwoArgJacobianPrep, + backend::AutoPolyesterForwardDiff{K}, + x, + contexts::Vararg{DI.Context, C}, + ) where {K, C} DI.check_prep(f!, y, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) fc!(y, x) @@ -231,15 +231,15 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, - y, - prep::PolyesterForwardDiffTwoArgJacobianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::PolyesterForwardDiffTwoArgJacobianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) threaded_jacobian!(fc!, y, jac, x, prep.chunk) @@ -252,16 +252,16 @@ function DI.jacobian( end function DI.jacobian!( - f!, - y, - jac, - prep::PolyesterForwardDiffTwoArgJacobianPrep, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::PolyesterForwardDiffTwoArgJacobianPrep, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) - if contexts isa NTuple{C,DI.GeneralizedConstant} + if contexts isa NTuple{C, DI.GeneralizedConstant} fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) return jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl index 0322023fb..350dc1be4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl @@ -7,8 +7,8 @@ function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, N::Integer) end function DI.threshold_batchsize( - backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer -) where {chunksize1} + backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer + ) where {chunksize1} chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2) - return AutoPolyesterForwardDiff(; chunksize, tag=backend.tag) + return AutoPolyesterForwardDiff(; chunksize, tag = backend.tag) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 753e67b20..093ecca9d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -1,20 +1,20 @@ ## Pullback function DI.prepare_pullback_nokwarg( - strict::Val, f, backend::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) end function DI.value_and_pullback( - f, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::AbstractArray, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -30,14 +30,14 @@ function DI.value_and_pullback( end function DI.value_and_pullback!( - f, - tx::NTuple, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::AbstractArray, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + tx::NTuple, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -55,13 +55,13 @@ function DI.value_and_pullback!( end function DI.value_and_pullback( - f, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::Number, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::Number, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) x_array = [x] f_array(x_array, args...) = f(only(x_array), args...) @@ -73,15 +73,15 @@ end ### Without contexts -struct ReverseDiffGradientPrep{SIG,C,T} <: DI.GradientPrep{SIG} +struct ReverseDiffGradientPrep{SIG, C, T} <: DI.GradientPrep{SIG} _sig::Val{SIG} config::C tape::T end function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoReverseDiff{compile}, x -) where {compile} + strict::Val, f, backend::AutoReverseDiff{compile}, x + ) where {compile} _sig = DI.signature(f, backend, x; strict) if compile tape = ReverseDiff.compile(GradientTape(f, x)) @@ -93,8 +93,8 @@ function DI.prepare_gradient_nokwarg( end function DI.value_and_gradient!( - f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251 if compile @@ -106,8 +106,8 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) # GradientResult tries to mutate an SArray result = MutableDiffResult(zero(eltype(x)), (similar(x),)) @@ -120,8 +120,8 @@ function DI.value_and_gradient( end function DI.gradient!( - f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) if compile return gradient!(grad, prep.tape, x) @@ -131,8 +131,8 @@ function DI.gradient!( end function DI.gradient( - f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) if compile return gradient!(prep.tape, x) @@ -144,21 +144,21 @@ end ### With contexts function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) config = GradientConfig(x) return ReverseDiffGradientPrep(_sig, config, nothing) end function DI.value_and_gradient!( - f, - grad, - prep::ReverseDiffGradientPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251 @@ -167,12 +167,12 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f, - prep::ReverseDiffGradientPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) # GradientResult tries to mutate an SArray @@ -182,25 +182,25 @@ function DI.value_and_gradient( end function DI.gradient!( - f, - grad, - prep::ReverseDiffGradientPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient!(grad, fc, x, prep.config) end function DI.gradient( - f, - prep::ReverseDiffGradientPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient(fc, x, prep.config) @@ -210,15 +210,15 @@ end ### Without contexts -struct ReverseDiffOneArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} +struct ReverseDiffOneArgJacobianPrep{SIG, C, T} <: DI.JacobianPrep{SIG} _sig::Val{SIG} config::C tape::T end function DI.prepare_jacobian_nokwarg( - strict::Val, f, backend::AutoReverseDiff{compile}, x -) where {compile} + strict::Val, f, backend::AutoReverseDiff{compile}, x + ) where {compile} _sig = DI.signature(f, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f, x)) @@ -230,8 +230,8 @@ function DI.prepare_jacobian_nokwarg( end function DI.value_and_jacobian!( - f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) y = f(x) result = DiffResult(y, (jac,)) @@ -246,8 +246,8 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) if compile return f(x), jacobian!(prep.tape, x) @@ -257,8 +257,8 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) if compile return jacobian!(jac, prep.tape, x) @@ -268,8 +268,8 @@ function DI.jacobian!( end function DI.jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) if compile return jacobian!(prep.tape, x) @@ -281,21 +281,21 @@ end ### With contexts function DI.prepare_jacobian_nokwarg( - strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) config = JacobianConfig(x) return ReverseDiffOneArgJacobianPrep(_sig, config, nothing) end function DI.value_and_jacobian!( - f, - jac, - prep::ReverseDiffOneArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::ReverseDiffOneArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) @@ -307,37 +307,37 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f, - prep::ReverseDiffOneArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ReverseDiffOneArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return fc(x), jacobian(fc, x, prep.config) end function DI.jacobian!( - f, - jac, - prep::ReverseDiffOneArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::ReverseDiffOneArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian!(jac, fc, x, prep.config) end function DI.jacobian( - f, - prep::ReverseDiffOneArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ReverseDiffOneArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian(fc, x, prep.config) @@ -347,7 +347,7 @@ end ### Without contexts -struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep{SIG} +struct ReverseDiffHessianPrep{SIG, G <: ReverseDiffGradientPrep, HC, HT} <: DI.HessianPrep{SIG} _sig::Val{SIG} gradient_prep::G hessian_config::HC @@ -355,8 +355,8 @@ struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.Hessia end function DI.prepare_hessian_nokwarg( - strict::Val, f, backend::AutoReverseDiff{compile}, x -) where {compile} + strict::Val, f, backend::AutoReverseDiff{compile}, x + ) where {compile} _sig = DI.signature(f, backend, x; strict) gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x) if compile @@ -369,8 +369,8 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian!( - f, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) if compile return hessian!(hess, prep.hessian_tape, x) @@ -380,8 +380,8 @@ function DI.hessian!( end function DI.hessian( - f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) if compile return hessian!(prep.hessian_tape, x) @@ -391,8 +391,8 @@ function DI.hessian( end function DI.value_gradient_and_hessian!( - f, grad, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, grad, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) y = f(x) DI.gradient!(f, grad, prep.gradient_prep, backend, x) @@ -401,8 +401,8 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f, prep, backend, x) y = f(x) grad = DI.gradient(f, prep.gradient_prep, backend, x) @@ -413,8 +413,8 @@ end ### With contexts function DI.prepare_hessian_nokwarg( - strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) hessian_config = HessianConfig(x) @@ -422,39 +422,39 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian!( - f, - hess, - prep::ReverseDiffHessianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + hess, + prep::ReverseDiffHessianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian!(hess, fc, x, prep.hessian_config) end function DI.hessian( - f, - prep::ReverseDiffHessianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ReverseDiffHessianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian(fc, x, prep.hessian_config) end function DI.value_gradient_and_hessian!( - f, - grad, - hess, - prep::ReverseDiffHessianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + hess, + prep::ReverseDiffHessianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) @@ -463,12 +463,12 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f, - prep::ReverseDiffHessianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ReverseDiffHessianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index b58ba7dd4..dbc780917 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -1,14 +1,14 @@ ## Pullback function DI.prepare_pullback_nokwarg( - strict::Val, - f!, - y, - backend::AutoReverseDiff, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::AutoReverseDiff, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) end @@ -16,14 +16,14 @@ end ### Array in function DI.value_and_pullback( - f!, - y, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::AbstractArray, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) @@ -39,15 +39,15 @@ function DI.value_and_pullback( end function DI.value_and_pullback!( - f!, - y, - tx::NTuple, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::AbstractArray, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + tx::NTuple, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) @@ -64,14 +64,14 @@ function DI.value_and_pullback!( end function DI.pullback( - f!, - y, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::AbstractArray, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) @@ -86,15 +86,15 @@ function DI.pullback( end function DI.pullback!( - f!, - y, - tx::NTuple, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::AbstractArray, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + tx::NTuple, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) @@ -112,14 +112,14 @@ end ### Number in, not supported function DI.value_and_pullback( - f!, - y, - prep::DI.NoPullbackPrep, - backend::AutoReverseDiff, - x::Number, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, + x::Number, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) x_array = [x] function f!_array(_y::AbstractArray, _x_array, args...) @@ -133,15 +133,15 @@ end ### Without contexts -struct ReverseDiffTwoArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} +struct ReverseDiffTwoArgJacobianPrep{SIG, C, T} <: DI.JacobianPrep{SIG} _sig::Val{SIG} config::C tape::T end function DI.prepare_jacobian_nokwarg( - strict::Val, f!, y, backend::AutoReverseDiff{compile}, x -) where {compile} + strict::Val, f!, y, backend::AutoReverseDiff{compile}, x + ) where {compile} _sig = DI.signature(f!, y, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f!, y, x)) @@ -153,8 +153,8 @@ function DI.prepare_jacobian_nokwarg( end function DI.value_and_jacobian( - f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f!, y, prep, backend, x) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) @@ -167,8 +167,8 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f!, y, prep, backend, x) result = MutableDiffResult(y, (jac,)) if compile @@ -180,8 +180,8 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f!, y, prep, backend, x) if compile jac = jacobian!(prep.tape, x) @@ -192,8 +192,8 @@ function DI.jacobian( end function DI.jacobian!( - f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x -) where {compile} + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x + ) where {compile} DI.check_prep(f!, y, prep, backend, x) if compile jac = jacobian!(jac, prep.tape, x) @@ -206,21 +206,21 @@ end ### With contexts function DI.prepare_jacobian_nokwarg( - strict::Val, f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) config = JacobianConfig(y, x) return ReverseDiffTwoArgJacobianPrep(_sig, config, nothing) end function DI.value_and_jacobian( - f!, - y, - prep::ReverseDiffTwoArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::ReverseDiffTwoArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) @@ -230,14 +230,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, - y, - jac, - prep::ReverseDiffTwoArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::ReverseDiffTwoArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (jac,)) @@ -246,13 +246,13 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, - y, - prep::ReverseDiffTwoArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::ReverseDiffTwoArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = jacobian(fc!, y, x, prep.config) @@ -260,14 +260,14 @@ function DI.jacobian( end function DI.jacobian!( - f!, - y, - jac, - prep::ReverseDiffTwoArgJacobianPrep, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::ReverseDiffTwoArgJacobianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = jacobian!(jac, fc!, y, x, prep.config) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl index 203992008..7893b19c7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl @@ -28,7 +28,7 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite if DI.pushforward_performance(backend) isa DI.PushforwardFast p = similar(y) prep = DI.prepare_pushforward_same_point( - f, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) + f, backend, x, (DI.basis(x, first(eachindex(x))),); strict = Val(true) ) for (kj, j) in enumerate(eachindex(x)) pushforward!(f, (p,), prep, backend, x, (DI.basis(x, j),)) @@ -42,7 +42,7 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite else p = similar(x) prep = DI.prepare_pullback_same_point( - f, backend, x, (DI.basis(y, first(eachindex(y))),); strict=Val(true) + f, backend, x, (DI.basis(y, first(eachindex(y))),); strict = Val(true) ) for (ki, i) in enumerate(eachindex(y)) pullback!(f, (p,), prep, backend, x, (DI.basis(y, i),)) @@ -64,7 +64,7 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{ if DI.pushforward_performance(backend) isa DI.PushforwardFast p = similar(y) prep = DI.prepare_pushforward_same_point( - f!, y, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) + f!, y, backend, x, (DI.basis(x, first(eachindex(x))),); strict = Val(true) ) for (kj, j) in enumerate(eachindex(x)) pushforward!(f!, y, (p,), prep, backend, x, (DI.basis(x, j),)) @@ -78,7 +78,7 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{ else p = similar(x) prep = DI.prepare_pullback_same_point( - f!, y, backend, x, (DI.basis(y, first(eachindex(y))),); strict=Val(true) + f!, y, backend, x, (DI.basis(y, first(eachindex(y))),); strict = Val(true) ) for (ki, i) in enumerate(eachindex(y)) pullback!(f!, y, (p,), prep, backend, x, (DI.basis(y, i),)) @@ -99,7 +99,7 @@ function ADTypes.hessian_sparsity(f, x, detector::DI.DenseSparsityDetector{:iter I, J = Int[], Int[] p = similar(x) prep = DI.prepare_hvp_same_point( - f, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) + f, backend, x, (DI.basis(x, first(eachindex(x))),); strict = Val(true) ) for (kj, j) in enumerate(eachindex(x)) hvp!(f, (p,), prep, backend, x, (DI.basis(x, j),)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl index 2bb96701a..05e50fb82 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl @@ -5,14 +5,14 @@ import DifferentiationInterface as DI using SparseConnectivityTracer: TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer -@inline function _translate(::Type, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}) +@inline function _translate(::Type, c::Union{DI.GeneralizedConstant, DI.ConstantOrCache}) return DI.unwrap(c) end @inline function _translate(::Type{T}, c::DI.Cache) where {T} return DI.recursive_similar(DI.unwrap(c), T) end -function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C} +function jacobian_translate(detector, x, contexts::Vararg{DI.Context, C}) where {C} T = eltype(jacobian_buffer(x, detector)) new_contexts = map(contexts) do c _translate(T, c) @@ -20,7 +20,7 @@ function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where { return new_contexts end -function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C} +function hessian_translate(detector, x, contexts::Vararg{DI.Context, C}) where {C} T = eltype(hessian_buffer(x, detector)) new_contexts = map(contexts) do c _translate(T, c) @@ -29,34 +29,34 @@ function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C end function DI.jacobian_sparsity_with_contexts( - f::F, - detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + detector::Union{TracerSparsityDetector, TracerLocalSparsityDetector}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} contexts_tracer = jacobian_translate(detector, x, contexts...) fc = DI.fix_tail(f, contexts_tracer...) return jacobian_sparsity(fc, x, detector) end function DI.jacobian_sparsity_with_contexts( - f!::F, - y, - detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + detector::Union{TracerSparsityDetector, TracerLocalSparsityDetector}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} contexts_tracer = jacobian_translate(detector, x, contexts...) fc! = DI.fix_tail(f!, contexts_tracer...) return jacobian_sparsity(fc!, y, x, detector) end function DI.hessian_sparsity_with_contexts( - f::F, - detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + detector::Union{TracerSparsityDetector, TracerLocalSparsityDetector}, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} contexts_tracer = hessian_translate(detector, x, contexts...) fc = DI.fix_tail(f, contexts_tracer...) return hessian_sparsity(fc, x, detector) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index f3925c89d..495b31b10 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -1,14 +1,14 @@ struct SMCSparseHessianPrep{ - SIG, - BS<:DI.BatchSizeSettings, - P<:AbstractMatrix, - C<:AbstractColoringResult{:symmetric,:column}, - M<:AbstractMatrix{<:Number}, - S<:AbstractVector{<:NTuple}, - R<:AbstractVector{<:NTuple}, - E2<:DI.HVPPrep, - E1<:DI.GradientPrep, -} <: DI.SparseHessianPrep{SIG} + SIG, + BS <: DI.BatchSizeSettings, + P <: AbstractMatrix, + C <: AbstractColoringResult{:symmetric, :column}, + M <: AbstractMatrix{<:Number}, + S <: AbstractVector{<:NTuple}, + R <: AbstractVector{<:NTuple}, + E2 <: DI.HVPPrep, + E1 <: DI.GradientPrep, + } <: DI.SparseHessianPrep{SIG} _sig::Val{SIG} batch_size_settings::BS sparsity::P @@ -23,15 +23,15 @@ end ## Hessian, one argument function DI.prepare_hessian_nokwarg( - strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} -) where {F,C} + strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context, C} + ) where {F, C} dense_backend = dense_ad(backend) sparsity = DI.hessian_sparsity_with_contexts( f, sparsity_detector(backend), x, contexts... ) - problem = ColoringProblem{:symmetric,:column}() + problem = ColoringProblem{:symmetric, :column}() coloring_result = coloring( - sparsity, problem, coloring_algorithm(backend); decompression_eltype=eltype(x) + sparsity, problem, coloring_algorithm(backend); decompression_eltype = eltype(x) ) N = length(column_groups(coloring_result)) batch_size_settings = DI.pick_batchsize(DI.outer(dense_backend), N) @@ -41,21 +41,21 @@ function DI.prepare_hessian_nokwarg( end function _prepare_sparse_hessian_aux( - strict::Val, - batch_size_settings::DI.BatchSizeSettings{B}, - sparsity::AbstractMatrix, - coloring_result::AbstractColoringResult{:symmetric,:column}, - f::F, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}; -) where {B,F,C} + strict::Val, + batch_size_settings::DI.BatchSizeSettings{B}, + sparsity::AbstractMatrix, + coloring_result::AbstractColoringResult{:symmetric, :column}, + f::F, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C} + ) where {B, F, C} _sig = DI.signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) seeds = [DI.multibasis(x, eachindex(x)[group]) for group in groups] - compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2) + compressed_matrix = stack(_ -> vec(similar(x)), groups; dims = 2) batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] @@ -80,13 +80,13 @@ function _prepare_sparse_hessian_aux( end function DI.hessian!( - f::F, - hess, - prep::SMCSparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}}, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,SIG,B,C} + f::F, + hess, + prep::SMCSparseHessianPrep{SIG, <:DI.BatchSizeSettings{B}}, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, SIG, B, C} DI.check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, @@ -127,22 +127,22 @@ function DI.hessian!( end function DI.hessian( - f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} -) where {F,C} + f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) hess = similar(sparsity_pattern(prep), eltype(x)) return DI.hessian!(f, hess, prep, backend, x, contexts...) end function DI.value_gradient_and_hessian!( - f::F, - grad, - hess, - prep::SMCSparseHessianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + grad, + hess, + prep::SMCSparseHessianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, DI.inner(dense_ad(backend)), x, contexts... @@ -152,8 +152,8 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} -) where {F,C} + f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, DI.inner(dense_ad(backend)), x, contexts... diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 6c0954db7..7e30d1439 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -1,15 +1,15 @@ ## Preparation struct SMCPushforwardSparseJacobianPrep{ - SIG, - BS<:DI.BatchSizeSettings, - P<:AbstractMatrix, - C<:AbstractColoringResult{:nonsymmetric,:column}, - M<:AbstractMatrix{<:Number}, - S<:AbstractVector{<:NTuple}, - R<:AbstractVector{<:NTuple}, - E<:DI.PushforwardPrep, -} <: SMCSparseJacobianPrep{SIG} + SIG, + BS <: DI.BatchSizeSettings, + P <: AbstractMatrix, + C <: AbstractColoringResult{:nonsymmetric, :column}, + M <: AbstractMatrix{<:Number}, + S <: AbstractVector{<:NTuple}, + R <: AbstractVector{<:NTuple}, + E <: DI.PushforwardPrep, + } <: SMCSparseJacobianPrep{SIG} _sig::Val{SIG} batch_size_settings::BS sparsity::P @@ -21,15 +21,15 @@ struct SMCPushforwardSparseJacobianPrep{ end struct SMCPullbackSparseJacobianPrep{ - SIG, - BS<:DI.BatchSizeSettings, - P<:AbstractMatrix, - C<:AbstractColoringResult{:nonsymmetric,:row}, - M<:AbstractMatrix{<:Number}, - S<:AbstractVector{<:NTuple}, - R<:AbstractVector{<:NTuple}, - E<:DI.PullbackPrep, -} <: SMCSparseJacobianPrep{SIG} + SIG, + BS <: DI.BatchSizeSettings, + P <: AbstractMatrix, + C <: AbstractColoringResult{:nonsymmetric, :row}, + M <: AbstractMatrix{<:Number}, + S <: AbstractVector{<:NTuple}, + R <: AbstractVector{<:NTuple}, + E <: DI.PullbackPrep, + } <: SMCSparseJacobianPrep{SIG} _sig::Val{SIG} batch_size_settings::BS sparsity::P @@ -41,8 +41,8 @@ struct SMCPullbackSparseJacobianPrep{ end function DI.prepare_jacobian_nokwarg( - strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} -) where {F,C} + strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context, C} + ) where {F, C} dense_backend = dense_ad(backend) y = f(x, map(DI.unwrap, contexts)...) perf = DI.pushforward_performance(dense_backend) @@ -50,36 +50,36 @@ function DI.prepare_jacobian_nokwarg( end function DI.prepare_jacobian_nokwarg( - strict::Val, f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} -) where {F,C} + strict::Val, f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context, C} + ) where {F, C} dense_backend = dense_ad(backend) perf = DI.pushforward_performance(dense_backend) return _prepare_sparse_jacobian_aux(strict, perf, y, (f!, y), backend, x, contexts...) end function _prepare_sparse_jacobian_aux( - strict::Val, - perf::DI.PushforwardPerformance, - y, - f_or_f!y::FY, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}; -) where {FY,C} + strict::Val, + perf::DI.PushforwardPerformance, + y, + f_or_f!y::FY, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C} + ) where {FY, C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( f_or_f!y..., sparsity_detector(backend), x, contexts... ) if perf isa DI.PushforwardFast - problem = ColoringProblem{:nonsymmetric,:column}() + problem = ColoringProblem{:nonsymmetric, :column}() else - problem = ColoringProblem{:nonsymmetric,:row}() + problem = ColoringProblem{:nonsymmetric, :row}() end coloring_result = coloring( sparsity, problem, coloring_algorithm(backend); - decompression_eltype=promote_type(eltype(x), eltype(y)), + decompression_eltype = promote_type(eltype(x), eltype(y)), ) if perf isa DI.PushforwardFast N = length(column_groups(coloring_result)) @@ -101,22 +101,22 @@ function _prepare_sparse_jacobian_aux( end function _prepare_sparse_jacobian_aux_aux( - strict::Val, - batch_size_settings::DI.BatchSizeSettings{B}, - sparsity::AbstractMatrix, - coloring_result::AbstractColoringResult{:nonsymmetric,:column}, - y, - f_or_f!y::FY, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}; -) where {B,FY,C} + strict::Val, + batch_size_settings::DI.BatchSizeSettings{B}, + sparsity::AbstractMatrix, + coloring_result::AbstractColoringResult{:nonsymmetric, :column}, + y, + f_or_f!y::FY, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C} + ) where {B, FY, C} _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) seeds = [DI.multibasis(x, eachindex(x)[group]) for group in groups] - compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2) + compressed_matrix = stack(_ -> vec(similar(y)), groups; dims = 2) batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] @@ -137,22 +137,22 @@ function _prepare_sparse_jacobian_aux_aux( end function _prepare_sparse_jacobian_aux_aux( - strict::Val, - batch_size_settings::DI.BatchSizeSettings{B}, - sparsity::AbstractMatrix, - coloring_result::AbstractColoringResult{:nonsymmetric,:row}, - y, - f_or_f!y::FY, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}; -) where {B,FY,C} + strict::Val, + batch_size_settings::DI.BatchSizeSettings{B}, + sparsity::AbstractMatrix, + coloring_result::AbstractColoringResult{:nonsymmetric, :row}, + y, + f_or_f!y::FY, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C} + ) where {B, FY, C} _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = row_groups(coloring_result) seeds = [DI.multibasis(y, eachindex(y)[group]) for group in groups] - compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1) + compressed_matrix = stack(_ -> vec(similar(x)), groups; dims = 1) batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] @@ -175,89 +175,89 @@ end ## One argument function DI.jacobian!( - f::F, - jac, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + jac, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) return _sparse_jacobian_aux!((f,), jac, prep, backend, x, contexts...) end function DI.jacobian( - f::F, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) jac = similar(sparsity_pattern(prep), eltype(x)) return DI.jacobian!(f, jac, prep, backend, x, contexts...) end function DI.value_and_jacobian( - f::F, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end function DI.value_and_jacobian!( - f::F, - jac, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f::F, + jac, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.jacobian!(f, jac, prep, backend, x, contexts...) + DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Two arguments function DI.jacobian!( - f!::F, - y, - jac, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + jac, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) return _sparse_jacobian_aux!((f!, y), jac, prep, backend, x, contexts...) end function DI.jacobian( - f!::F, - y, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) jac = similar(sparsity_pattern(prep), promote_type(eltype(x), eltype(y))) return DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) end function DI.value_and_jacobian( - f!::F, - y, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) jac = DI.jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -265,14 +265,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!::F, - y, - jac, - prep::SMCSparseJacobianPrep, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} + f!::F, + y, + jac, + prep::SMCSparseJacobianPrep, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {F, C} DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -282,13 +282,13 @@ end ## Common auxiliaries function _sparse_jacobian_aux!( - f_or_f!y::FY, - jac, - prep::SMCPushforwardSparseJacobianPrep{SIG,<:DI.BatchSizeSettings{B}}, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {FY,SIG,B,C} + f_or_f!y::FY, + jac, + prep::SMCPushforwardSparseJacobianPrep{SIG, <:DI.BatchSizeSettings{B}}, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {FY, SIG, B, C} (; batch_size_settings, coloring_result, @@ -328,13 +328,13 @@ function _sparse_jacobian_aux!( end function _sparse_jacobian_aux!( - f_or_f!y::FY, - jac, - prep::SMCPullbackSparseJacobianPrep{SIG,<:DI.BatchSizeSettings{B}}, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {FY,SIG,B,C} + f_or_f!y::FY, + jac, + prep::SMCPullbackSparseJacobianPrep{SIG, <:DI.BatchSizeSettings{B}}, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {FY, SIG, B, C} (; batch_size_settings, coloring_result, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index cd611d5cd..97608a535 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -1,22 +1,22 @@ ## Preparation struct SMCMixedModeSparseJacobianPrep{ - SIG, - BSf<:DI.BatchSizeSettings, - BSr<:DI.BatchSizeSettings, - P<:AbstractMatrix, - C<:AbstractColoringResult{:nonsymmetric,:bidirectional}, - Mf<:AbstractMatrix{<:Number}, - Mr<:AbstractMatrix{<:Number}, - Sfp<:NTuple, - Srp<:NTuple, - Sf<:Vector{<:NTuple}, - Sr<:Vector{<:NTuple}, - Rf<:Vector{<:NTuple}, - Rr<:Vector{<:NTuple}, - Ef<:DI.PushforwardPrep, - Er<:DI.PullbackPrep, -} <: SMCSparseJacobianPrep{SIG} + SIG, + BSf <: DI.BatchSizeSettings, + BSr <: DI.BatchSizeSettings, + P <: AbstractMatrix, + C <: AbstractColoringResult{:nonsymmetric, :bidirectional}, + Mf <: AbstractMatrix{<:Number}, + Mr <: AbstractMatrix{<:Number}, + Sfp <: NTuple, + Srp <: NTuple, + Sf <: Vector{<:NTuple}, + Sr <: Vector{<:NTuple}, + Rf <: Vector{<:NTuple}, + Rr <: Vector{<:NTuple}, + Ef <: DI.PushforwardPrep, + Er <: DI.PullbackPrep, + } <: SMCSparseJacobianPrep{SIG} _sig::Val{SIG} batch_size_settings_forward::BSf batch_size_settings_reverse::BSr @@ -35,45 +35,45 @@ struct SMCMixedModeSparseJacobianPrep{ end function DI.prepare_jacobian_nokwarg( - strict::Val, - f::F, - backend::AutoSparse{<:DI.MixedMode}, - x, - contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context, C} + ) where {F, C} y = f(x, map(DI.unwrap, contexts)...) return _prepare_mixed_sparse_jacobian_aux(strict, y, (f,), backend, x, contexts...) end function DI.prepare_jacobian_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoSparse{<:DI.MixedMode}, - x, - contexts::Vararg{DI.Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context, C} + ) where {F, C} return _prepare_mixed_sparse_jacobian_aux(strict, y, (f!, y), backend, x, contexts...) end function _prepare_mixed_sparse_jacobian_aux( - strict::Val, - y, - f_or_f!y::FY, - backend::AutoSparse{<:DI.MixedMode}, - x, - contexts::Vararg{DI.Context,C}; -) where {FY,C} + strict::Val, + y, + f_or_f!y::FY, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context, C} + ) where {FY, C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( f_or_f!y..., sparsity_detector(backend), x, contexts... ) - problem = ColoringProblem{:nonsymmetric,:bidirectional}() + problem = ColoringProblem{:nonsymmetric, :bidirectional}() coloring_result = coloring( sparsity, problem, coloring_algorithm(backend); - decompression_eltype=promote_type(eltype(x), eltype(y)), + decompression_eltype = promote_type(eltype(x), eltype(y)), ) Nf = length(column_groups(coloring_result)) @@ -91,22 +91,22 @@ function _prepare_mixed_sparse_jacobian_aux( f_or_f!y, backend, x, - contexts...; + contexts... ) end function _prepare_mixed_sparse_jacobian_aux_aux( - strict::Val, - batch_size_settings_forward::DI.BatchSizeSettings{Bf}, - batch_size_settings_reverse::DI.BatchSizeSettings{Br}, - sparsity::AbstractMatrix, - coloring_result::AbstractColoringResult{:nonsymmetric,:bidirectional}, - y, - f_or_f!y::FY, - backend::AutoSparse{<:DI.MixedMode}, - x, - contexts::Vararg{DI.Context,C}; -) where {Bf,Br,FY,C} + strict::Val, + batch_size_settings_forward::DI.BatchSizeSettings{Bf}, + batch_size_settings_reverse::DI.BatchSizeSettings{Br}, + sparsity::AbstractMatrix, + coloring_result::AbstractColoringResult{:nonsymmetric, :bidirectional}, + y, + f_or_f!y::FY, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context, C} + ) where {Bf, Br, FY, C} _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) Nf, Af = batch_size_settings_forward.N, batch_size_settings_forward.A Nr, Ar = batch_size_settings_reverse.N, batch_size_settings_reverse.A @@ -124,12 +124,12 @@ function _prepare_mixed_sparse_jacobian_aux_aux( compressed_matrix_forward = if isempty(groups_forward) similar(vec(y), length(y), 0) else - stack(_ -> vec(similar(y)), groups_forward; dims=2) + stack(_ -> vec(similar(y)), groups_forward; dims = 2) end compressed_matrix_reverse = if isempty(groups_reverse) similar(vec(x), 0, length(x)) else - stack(_ -> vec(similar(x)), groups_reverse; dims=1) + stack(_ -> vec(similar(x)), groups_reverse; dims = 1) end batched_seed_forward_prep = ntuple(b -> copy(seed_forward_prep), Val(Bf)) @@ -154,7 +154,7 @@ function _prepare_mixed_sparse_jacobian_aux_aux( DI.forward_backend(dense_backend), x, batched_seed_forward_prep, - contexts...; + contexts... ) pullback_prep = DI.prepare_pullback_nokwarg( strict, @@ -162,7 +162,7 @@ function _prepare_mixed_sparse_jacobian_aux_aux( DI.reverse_backend(dense_backend), x, batched_seed_reverse_prep, - contexts...; + contexts... ) return SMCMixedModeSparseJacobianPrep( @@ -187,15 +187,15 @@ end ## Common auxiliaries function _sparse_jacobian_aux!( - f_or_f!y::FY, - jac, - prep::SMCMixedModeSparseJacobianPrep{ - SIG,<:DI.BatchSizeSettings{Bf},<:DI.BatchSizeSettings{Br} - }, - backend::AutoSparse, - x, - contexts::Vararg{DI.Context,C}, -) where {FY,SIG,Bf,Br,C} + f_or_f!y::FY, + jac, + prep::SMCMixedModeSparseJacobianPrep{ + SIG, <:DI.BatchSizeSettings{Bf}, <:DI.BatchSizeSettings{Br}, + }, + backend::AutoSparse, + x, + contexts::Vararg{DI.Context, C}, + ) where {FY, SIG, Bf, Br, C} (; batch_size_settings_forward, batch_size_settings_reverse, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl index 43c5abff4..2c48fc981 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl @@ -4,31 +4,31 @@ using ADTypes: AutoForwardDiff, AutoEnzyme import DifferentiationInterface as DI using StaticArrays: SArray, StaticArray -function DI.stack_vec_col(t::NTuple{B,<:StaticArray}) where {B} +function DI.stack_vec_col(t::NTuple{B, <:StaticArray}) where {B} return hcat(map(vec, t)...) end -function DI.stack_vec_row(t::NTuple{B,<:StaticArray}) where {B} +function DI.stack_vec_row(t::NTuple{B, <:StaticArray}) where {B} return vcat(transpose.(map(vec, t))...) end DI.ismutable_array(::Type{<:SArray}) = false function DI.pick_batchsize(::DI.AutoSimpleFiniteDiff{nothing}, x::StaticArray) - return DI.BatchSizeSettings{length(x),true,true}(length(x)) + return DI.BatchSizeSettings{length(x), true, true}(length(x)) end function DI.pick_batchsize(::AutoForwardDiff{nothing}, x::StaticArray) - return DI.BatchSizeSettings{length(x),true,true}(length(x)) + return DI.BatchSizeSettings{length(x), true, true}(length(x)) end function DI.pick_batchsize(::AutoEnzyme, x::StaticArray) - return DI.BatchSizeSettings{length(x),true,true}(length(x)) + return DI.BatchSizeSettings{length(x), true, true}(length(x)) end function DI.pick_batchsize( - ::DI.AutoSimpleFiniteDiff{chunksize}, x::StaticArray -) where {chunksize} + ::DI.AutoSimpleFiniteDiff{chunksize}, x::StaticArray + ) where {chunksize} return DI.BatchSizeSettings{chunksize}(Val(length(x))) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 1179861e6..e34424d25 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -27,7 +27,7 @@ dense_ad(backend::AutoSparse{<:AutoSymbolics}) = ADTypes.dense_ad(backend) variablize(::Number, name::Symbol) = variable(name) variablize(x::AbstractArray, name::Symbol) = variables(name, axes(x)...) -function variablize(contexts::NTuple{C,DI.Context}) where {C} +function variablize(contexts::NTuple{C, DI.Context}) where {C} return ntuple(Val(C)) do k c = contexts[k] variablize(DI.unwrap(c), Symbol("context$k")) @@ -35,14 +35,15 @@ function variablize(contexts::NTuple{C,DI.Context}) where {C} end function erase_cache_vars!( - context_vars::NTuple{C}, contexts::NTuple{C,DI.Context} -) where {C} + context_vars::NTuple{C}, contexts::NTuple{C, DI.Context} + ) where {C} # erase the active data from caches before building function for (v, c) in zip(context_vars, contexts) if c isa DI.Cache fill!(v, zero(eltype(v))) end end + return end include("onearg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 720a8ad47..94079fcce 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -1,14 +1,14 @@ ## Pushforward -struct SymbolicsOneArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} +struct SymbolicsOneArgPushforwardPrep{SIG, E1, E1!} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} pf_exe::E1 pf_exe!::E1! end function DI.prepare_pushforward_nokwarg( - strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) @@ -20,7 +20,7 @@ function DI.prepare_pushforward_nokwarg( erase_cache_vars!(context_vars, contexts) res = build_function( - pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true + pf_var, x_var, dx_var, context_vars...; expression = Val(false), cse = true ) (pf_exe, pf_exe!) = if res isa Tuple res @@ -31,13 +31,13 @@ function DI.prepare_pushforward_nokwarg( end function DI.pushforward( - f, - prep::SymbolicsOneArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) ty = map(tx) do dx dy = prep.pf_exe(x, dx, map(DI.unwrap, contexts)...) @@ -46,14 +46,14 @@ function DI.pushforward( end function DI.pushforward!( - f, - ty::NTuple, - prep::SymbolicsOneArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + ty::NTuple, + prep::SymbolicsOneArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -63,50 +63,50 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f, - prep::SymbolicsOneArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pushforward(f, prep, backend, x, tx, contexts...) + DI.pushforward(f, prep, backend, x, tx, contexts...) end function DI.value_and_pushforward!( - f, - ty::NTuple, - prep::SymbolicsOneArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + ty::NTuple, + prep::SymbolicsOneArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) + DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Derivative -struct SymbolicsOneArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} +struct SymbolicsOneArgDerivativePrep{SIG, E1, E1!} <: DI.DerivativePrep{SIG} _sig::Val{SIG} der_exe::E1 der_exe!::E1! end function DI.prepare_derivative_nokwarg( - strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) erase_cache_vars!(context_vars, contexts) - res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) + res = build_function(der_var, x_var, context_vars...; expression = Val(false), cse = true) (der_exe, der_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -116,65 +116,65 @@ function DI.prepare_derivative_nokwarg( end function DI.derivative( - f, - prep::SymbolicsOneArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return prep.der_exe(x, map(DI.unwrap, contexts)...) end function DI.derivative!( - f, - der, - prep::SymbolicsOneArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + prep::SymbolicsOneArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.der_exe!(der, x, map(DI.unwrap, contexts)...) return der end function DI.value_and_derivative( - f, - prep::SymbolicsOneArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.derivative(f, prep, backend, x, contexts...) + DI.derivative(f, prep, backend, x, contexts...) end function DI.value_and_derivative!( - f, - der, - prep::SymbolicsOneArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + prep::SymbolicsOneArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.derivative!(f, der, prep, backend, x, contexts...) + DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient -struct SymbolicsOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} +struct SymbolicsOneArgGradientPrep{SIG, E1, E1!} <: DI.GradientPrep{SIG} _sig::Val{SIG} grad_exe::E1 grad_exe!::E1! end function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -183,63 +183,63 @@ function DI.prepare_gradient_nokwarg( erase_cache_vars!(context_vars, contexts) res = build_function( - grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true + grad_var, vec(x_var), context_vars...; expression = Val(false), cse = true ) (grad_exe, grad_exe!) = res return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!) end function DI.gradient( - f, - prep::SymbolicsOneArgGradientPrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return reshape(prep.grad_exe(vec(x), map(DI.unwrap, contexts)...), size(x)) end function DI.gradient!( - f, - grad, - prep::SymbolicsOneArgGradientPrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.grad_exe!(vec(grad), vec(x), map(DI.unwrap, contexts)...) return grad end function DI.value_and_gradient( - f, - prep::SymbolicsOneArgGradientPrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end function DI.value_and_gradient!( - f, - grad, - prep::SymbolicsOneArgGradientPrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.gradient!(f, grad, prep, backend, x, contexts...) + DI.gradient!(f, grad, prep, backend, x, contexts...) end ## Jacobian -struct SymbolicsOneArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG} +struct SymbolicsOneArgJacobianPrep{SIG, P, E1, E1!} <: DI.SparseJacobianPrep{SIG} _sig::Val{SIG} sparsity::P jac_exe::E1 @@ -247,12 +247,12 @@ struct SymbolicsOneArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG} end function DI.prepare_jacobian_nokwarg( - strict::Val, - f, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -265,62 +265,62 @@ function DI.prepare_jacobian_nokwarg( end erase_cache_vars!(context_vars, contexts) - res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) + res = build_function(jac_var, x_var, context_vars...; expression = Val(false), cse = true) (jac_exe, jac_exe!) = res return SymbolicsOneArgJacobianPrep(_sig, sparsity, jac_exe, jac_exe!) end function DI.jacobian( - f, - prep::SymbolicsOneArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return prep.jac_exe(x, map(DI.unwrap, contexts)...) end function DI.jacobian!( - f, - jac, - prep::SymbolicsOneArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::SymbolicsOneArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end function DI.value_and_jacobian( - f, - prep::SymbolicsOneArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end function DI.value_and_jacobian!( - f, - jac, - prep::SymbolicsOneArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + jac, + prep::SymbolicsOneArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), - DI.jacobian!(f, jac, prep, backend, x, contexts...) + DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Hessian -struct SymbolicsOneArgHessianPrep{SIG,G,P,E2,E2!} <: DI.SparseHessianPrep{SIG} +struct SymbolicsOneArgHessianPrep{SIG, G, P, E2, E2!} <: DI.SparseHessianPrep{SIG} _sig::Val{SIG} gradient_prep::G sparsity::P @@ -329,12 +329,12 @@ struct SymbolicsOneArgHessianPrep{SIG,G,P,E2,E2!} <: DI.SparseHessianPrep{SIG} end function DI.prepare_hessian_nokwarg( - strict::Val, - f, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -349,7 +349,7 @@ function DI.prepare_hessian_nokwarg( erase_cache_vars!(context_vars, contexts) res = build_function( - hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true + hess_var, vec(x_var), context_vars...; expression = Val(false), cse = true ) (hess_exe, hess_exe!) = res @@ -360,36 +360,36 @@ function DI.prepare_hessian_nokwarg( end function DI.hessian( - f, - prep::SymbolicsOneArgHessianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgHessianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return prep.hess_exe(vec(x), map(DI.unwrap, contexts)...) end function DI.hessian!( - f, - hess, - prep::SymbolicsOneArgHessianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + hess, + prep::SymbolicsOneArgHessianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.hess_exe!(hess, vec(x), map(DI.unwrap, contexts)...) return hess end function DI.value_gradient_and_hessian( - f, - prep::SymbolicsOneArgHessianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgHessianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, dense_ad(backend), x, contexts... @@ -399,14 +399,14 @@ function DI.value_gradient_and_hessian( end function DI.value_gradient_and_hessian!( - f, - grad, - hess, - prep::SymbolicsOneArgHessianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + hess, + prep::SymbolicsOneArgHessianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... @@ -417,7 +417,7 @@ end ## HVP -struct SymbolicsOneArgHVPPrep{SIG,G,E2,E2!} <: DI.HVPPrep{SIG} +struct SymbolicsOneArgHVPPrep{SIG, G, E2, E2!} <: DI.HVPPrep{SIG} _sig::Val{SIG} gradient_prep::G hvp_exe::E2 @@ -425,8 +425,8 @@ struct SymbolicsOneArgHVPPrep{SIG,G,E2,E2!} <: DI.HVPPrep{SIG} end function DI.prepare_hvp_nokwarg( - strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) @@ -442,8 +442,8 @@ function DI.prepare_hvp_nokwarg( vec(x_var), vec(dx_var), context_vars...; - expression=Val(false), - cse=true, + expression = Val(false), + cse = true, ) (hvp_exe, hvp_exe!) = res @@ -452,13 +452,13 @@ function DI.prepare_hvp_nokwarg( end function DI.hvp( - f, - prep::SymbolicsOneArgHVPPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgHVPPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return map(tx) do dx dg_vec = prep.hvp_exe(vec(x), vec(dx), map(DI.unwrap, contexts)...) @@ -467,14 +467,14 @@ function DI.hvp( end function DI.hvp!( - f, - tg::NTuple, - prep::SymbolicsOneArgHVPPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + tg::NTuple, + prep::SymbolicsOneArgHVPPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] @@ -484,13 +484,13 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, - prep::SymbolicsOneArgHVPPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgHVPPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) tg = DI.hvp(f, prep, backend, x, tx, contexts...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) @@ -498,15 +498,15 @@ function DI.gradient_and_hvp( end function DI.gradient_and_hvp!( - f, - grad, - tg::NTuple, - prep::SymbolicsOneArgHVPPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + tg::NTuple, + prep::SymbolicsOneArgHVPPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hvp!(f, tg, prep, backend, x, tx, contexts...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) @@ -515,7 +515,7 @@ end ## Second derivative -struct SymbolicsOneArgSecondDerivativePrep{SIG,D,E1,E1!} <: DI.SecondDerivativePrep{SIG} +struct SymbolicsOneArgSecondDerivativePrep{SIG, D, E1, E1!} <: DI.SecondDerivativePrep{SIG} _sig::Val{SIG} derivative_prep::D der2_exe::E1 @@ -523,8 +523,8 @@ struct SymbolicsOneArgSecondDerivativePrep{SIG,D,E1,E1!} <: DI.SecondDerivativeP end function DI.prepare_second_derivative_nokwarg( - strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -532,7 +532,7 @@ function DI.prepare_second_derivative_nokwarg( der2_var = derivative(der_var, x_var) erase_cache_vars!(context_vars, contexts) - res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true) + res = build_function(der2_var, x_var, context_vars...; expression = Val(false), cse = true) (der2_exe, der2_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -543,36 +543,36 @@ function DI.prepare_second_derivative_nokwarg( end function DI.second_derivative( - f, - prep::SymbolicsOneArgSecondDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgSecondDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return prep.der2_exe(x, map(DI.unwrap, contexts)...) end function DI.second_derivative!( - f, - der2, - prep::SymbolicsOneArgSecondDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der2, + prep::SymbolicsOneArgSecondDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) prep.der2_exe!(der2, x, map(DI.unwrap, contexts)...) return der2 end function DI.value_derivative_and_second_derivative( - f, - prep::SymbolicsOneArgSecondDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::SymbolicsOneArgSecondDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) der2 = DI.second_derivative(f, prep, backend, x, contexts...) @@ -580,14 +580,14 @@ function DI.value_derivative_and_second_derivative( end function DI.value_derivative_and_second_derivative!( - f, - der, - der2, - prep::SymbolicsOneArgSecondDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + der, + der2, + prep::SymbolicsOneArgSecondDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) DI.second_derivative!(f, der2, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 8ec5dc7a6..8247fe6c0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -1,20 +1,20 @@ ## Pushforward -struct SymbolicsTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} +struct SymbolicsTwoArgPushforwardPrep{SIG, E1, E1!} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} pushforward_exe::E1 pushforward_exe!::E1! end function DI.prepare_pushforward_nokwarg( - strict::Val, - f!, - y, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) @@ -28,21 +28,21 @@ function DI.prepare_pushforward_nokwarg( erase_cache_vars!(context_vars, contexts) res = build_function( - pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true + pf_var, x_var, dx_var, context_vars...; expression = Val(false), cse = true ) (pushforward_exe, pushforward_exe!) = res return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!) end function DI.pushforward( - f!, - y, - prep::SymbolicsTwoArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::SymbolicsTwoArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx dy = prep.pushforward_exe(x, dx, map(DI.unwrap, contexts)...) @@ -51,15 +51,15 @@ function DI.pushforward( end function DI.pushforward!( - f!, - y, - ty::NTuple, - prep::SymbolicsTwoArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::SymbolicsTwoArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -69,14 +69,14 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f!, - y, - prep::SymbolicsTwoArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::SymbolicsTwoArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -84,15 +84,15 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f!, - y, - ty::NTuple, - prep::SymbolicsTwoArgPushforwardPrep, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + ty::NTuple, + prep::SymbolicsTwoArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -101,15 +101,15 @@ end ## Derivative -struct SymbolicsTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} +struct SymbolicsTwoArgDerivativePrep{SIG, E1, E1!} <: DI.DerivativePrep{SIG} _sig::Val{SIG} der_exe::E1 der_exe!::E1! end function DI.prepare_derivative_nokwarg( - strict::Val, f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) y_var = variablize(y, :y) @@ -118,45 +118,45 @@ function DI.prepare_derivative_nokwarg( der_var = derivative(y_var, x_var) erase_cache_vars!(context_vars, contexts) - res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) + res = build_function(der_var, x_var, context_vars...; expression = Val(false), cse = true) (der_exe, der_exe!) = res return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!) end function DI.derivative( - f!, - y, - prep::SymbolicsTwoArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::SymbolicsTwoArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) return prep.der_exe(x, map(DI.unwrap, contexts)...) end function DI.derivative!( - f!, - y, - der, - prep::SymbolicsTwoArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::SymbolicsTwoArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) prep.der_exe!(der, x, map(DI.unwrap, contexts)...) return der end function DI.value_and_derivative( - f!, - y, - prep::SymbolicsTwoArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::SymbolicsTwoArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) der = DI.derivative(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -164,14 +164,14 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!, - y, - der, - prep::SymbolicsTwoArgDerivativePrep, - backend::AutoSymbolics, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + der, + prep::SymbolicsTwoArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) DI.derivative!(f!, y, der, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -180,7 +180,7 @@ end ## Jacobian -struct SymbolicsTwoArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG} +struct SymbolicsTwoArgJacobianPrep{SIG, P, E1, E1!} <: DI.SparseJacobianPrep{SIG} _sig::Val{SIG} sparsity::P jac_exe::E1 @@ -188,13 +188,13 @@ struct SymbolicsTwoArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG} end function DI.prepare_jacobian_nokwarg( - strict::Val, - f!, - y, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, + f!, + y, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) y_var = variablize(y, :y) @@ -209,45 +209,45 @@ function DI.prepare_jacobian_nokwarg( end erase_cache_vars!(context_vars, contexts) - res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) + res = build_function(jac_var, x_var, context_vars...; expression = Val(false), cse = true) (jac_exe, jac_exe!) = res return SymbolicsTwoArgJacobianPrep(_sig, sparsity, jac_exe, jac_exe!) end function DI.jacobian( - f!, - y, - prep::SymbolicsTwoArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::SymbolicsTwoArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) return prep.jac_exe(x, map(DI.unwrap, contexts)...) end function DI.jacobian!( - f!, - y, - jac, - prep::SymbolicsTwoArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::SymbolicsTwoArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end function DI.value_and_jacobian( - f!, - y, - prep::SymbolicsTwoArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + prep::SymbolicsTwoArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) jac = DI.jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) @@ -255,14 +255,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, - y, - jac, - prep::SymbolicsTwoArgJacobianPrep, - backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, - x, - contexts::Vararg{DI.Context,C}, -) where {C} + f!, + y, + jac, + prep::SymbolicsTwoArgJacobianPrep, + backend::Union{AutoSymbolics, AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index d3e05294d..bae79b164 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -9,46 +9,46 @@ DI.inplace_support(::AutoTracker) = DI.InPlaceNotSupported() ## Pullback -struct TrackerPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} +struct TrackerPullbackPrepSamePoint{SIG, Y, PB} <: DI.PullbackPrep{SIG} _sig::Val{SIG} y::Y pb::PB end function DI.prepare_pullback_nokwarg( - strict::Val, - f, - backend::AutoTracker, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}; -) where {C} + strict::Val, + f, + backend::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C} + ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( - f, - prep::DI.NoPullbackPrep, - backend::AutoTracker, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) + _sig = DI.signature(f, backend, x, ty, contexts...; strict = DI.is_strict(prep)) y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( - f, - prep::DI.NoPullbackPrep, - backend::AutoTracker, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -58,13 +58,13 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f, - prep::TrackerPullbackPrepSamePoint, - backend::AutoTracker, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::TrackerPullbackPrepSamePoint, + backend::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy @@ -74,13 +74,13 @@ function DI.value_and_pullback( end function DI.pullback( - f, - prep::TrackerPullbackPrepSamePoint, - backend::AutoTracker, - x, - ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::TrackerPullbackPrepSamePoint, + backend::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; pb) = prep tx = map(ty) do dy @@ -92,57 +92,57 @@ end ## Gradient function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}; -) where {C} + strict::Val, f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoGradientPrep(_sig) end function DI.value_and_gradient( - f, - prep::DI.NoGradientPrep, - backend::AutoTracker, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, - prep::DI.NoGradientPrep, - backend::AutoTracker, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return data(first(grad)) end function DI.value_and_gradient!( - f, - grad, - prep::DI.NoGradientPrep, - backend::AutoTracker, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + grad, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end function DI.gradient!( - f, - grad, - prep::DI.NoGradientPrep, - backend::AutoTracker, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + grad, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 0843c5c05..cf2a0820b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -18,47 +18,47 @@ DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() translate(c::DI.Context) = DI.unwrap(c) translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c)) -function translate(c::DI.Cache{<:Union{Tuple,NamedTuple}}) +function translate(c::DI.Cache{<:Union{Tuple, NamedTuple}}) return map(translate, map(DI.Cache, DI.unwrap(c))) end ## Pullback -struct ZygotePullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} +struct ZygotePullbackPrepSamePoint{SIG, Y, PB} <: DI.PullbackPrep{SIG} _sig::Val{SIG} y::Y pb::PB end function DI.prepare_pullback_nokwarg( - strict::Val, f, backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}; -) where {C} + strict::Val, f, backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( - f, - prep::DI.NoPullbackPrep, - backend::AutoZygote, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) + _sig = DI.signature(f, backend, x, ty, contexts...; strict = DI.is_strict(prep)) y, pb = pullback(f, x, map(translate, contexts)...) return ZygotePullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( - f, - prep::DI.NoPullbackPrep, - backend::AutoZygote, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::DI.NoPullbackPrep, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = pullback(f, x, map(translate, contexts)...) tx = map(ty) do dy @@ -68,13 +68,13 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f, - prep::ZygotePullbackPrepSamePoint, - backend::AutoZygote, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ZygotePullbackPrepSamePoint, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy @@ -84,13 +84,13 @@ function DI.value_and_pullback( end function DI.pullback( - f, - prep::ZygotePullbackPrepSamePoint, - backend::AutoZygote, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ZygotePullbackPrepSamePoint, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) (; pb) = prep tx = map(ty) do dy @@ -102,39 +102,39 @@ end ## Gradient function DI.prepare_gradient_nokwarg( - strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoGradientPrep(_sig) end function DI.value_and_gradient( - f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; val, grad) = withgradient(f, x, map(translate, contexts)...) return val, first(grad) end function DI.gradient( - f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) grad = gradient(f, x, map(translate, contexts)...) return first(grad) end function DI.value_and_gradient!( - f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end function DI.gradient!( - f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end @@ -142,15 +142,15 @@ end ## Jacobian function DI.prepare_jacobian_nokwarg( - strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoJacobianPrep(_sig) end function DI.value_and_jacobian( - f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(translate, contexts)...) # https://github.com/FluxML/Zygote.jl/issues/1506 @@ -159,24 +159,24 @@ function DI.value_and_jacobian( end function DI.jacobian( - f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) jac = jacobian(f, x, map(translate, contexts)...) return first(jac) end function DI.value_and_jacobian!( - f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) end function DI.jacobian!( - f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} -) where {C} + f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context, C} + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -185,14 +185,14 @@ end # Beware, this uses ForwardDiff for the inner differentiation -struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG} +struct ZygoteHVPPrep{SIG, P} <: DI.HVPPrep{SIG} _sig::Val{SIG} fd_prep::P end function DI.prepare_hvp_nokwarg( - strict::Val, f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} -) where {C} + strict::Val, f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context, C} + ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) fd_prep = DI.prepare_hvp_nokwarg( strict, f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -201,13 +201,13 @@ function DI.prepare_hvp_nokwarg( end function DI.hvp( - f, - prep::ZygoteHVPPrep, - backend::AutoZygote, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ZygoteHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.hvp( f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -215,14 +215,14 @@ function DI.hvp( end function DI.hvp!( - f, - tg::NTuple, - prep::ZygoteHVPPrep, - backend::AutoZygote, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + tg::NTuple, + prep::ZygoteHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.hvp!( f, tg, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -230,13 +230,13 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, - prep::ZygoteHVPPrep, - backend::AutoZygote, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + prep::ZygoteHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp( f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -244,15 +244,15 @@ function DI.gradient_and_hvp( end function DI.gradient_and_hvp!( - f, - grad, - tg::NTuple, - prep::ZygoteHVPPrep, - backend::AutoZygote, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} + f, + grad, + tg::NTuple, + prep::ZygoteHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp!( f, @@ -269,19 +269,19 @@ end ## Hessian function DI.prepare_hessian_nokwarg( - strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} -) where {C} + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant, C} + ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoHessianPrep(_sig) end function DI.hessian( - f, - prep::DI.NoHessianPrep, - backend::AutoZygote, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) hess = hessian(fc, x) @@ -289,24 +289,24 @@ function DI.hessian( end function DI.hessian!( - f, - hess, - prep::DI.NoHessianPrep, - backend::AutoZygote, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + hess, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end function DI.value_gradient_and_hessian( - f, - prep::DI.NoHessianPrep, - backend::AutoZygote, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient(f, backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) @@ -314,14 +314,14 @@ function DI.value_gradient_and_hessian( end function DI.value_gradient_and_hessian!( - f, - grad, - hess, - prep::DI.NoHessianPrep, - backend::AutoZygote, - x, - contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} + f, + grad, + hess, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.GeneralizedConstant, C}, + ) where {C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!(f, grad, backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/src/compat.jl b/DifferentiationInterface/src/compat.jl index db72a3239..ec3076827 100644 --- a/DifferentiationInterface/src/compat.jl +++ b/DifferentiationInterface/src/compat.jl @@ -1,5 +1,5 @@ macro public(ex) - if VERSION >= v"1.11.0-DEV.469" + return if VERSION >= v"1.11.0-DEV.469" args = if ex isa Symbol (ex,) elseif Base.isexpr(ex, :tuple) diff --git a/DifferentiationInterface/src/docstrings.jl b/DifferentiationInterface/src/docstrings.jl index ff3359884..ab5e5d587 100644 --- a/DifferentiationInterface/src/docstrings.jl +++ b/DifferentiationInterface/src/docstrings.jl @@ -7,14 +7,14 @@ function docstring_preptype(preptype::AbstractString, operator::AbstractString) end function samepoint_warning(samepoint::Bool) - if samepoint + return if samepoint ", _if they are applied at the same point `x` and with the same `contexts`_" else "" end end -function docstring_prepare(operator; samepoint=false, inplace=false) +function docstring_prepare(operator; samepoint = false, inplace = false) return """ Create a `prep` object that can be given to [`$(operator)`](@ref) and its variants to speed them up$(samepoint_warning(samepoint)). @@ -49,7 +49,7 @@ function docstring_prepare!(operator) """ end -function docstring_preparation_hint(operator::AbstractString; same_point=false) +function docstring_preparation_hint(operator::AbstractString; same_point = false) if same_point return "To improve performance via operator preparation, refer to [`prepare_$(operator)`](@ref) and [`prepare_$(operator)_same_point`](@ref)." else diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index ca2314c97..45cde2b6c 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -4,17 +4,17 @@ prepare_derivative(f, backend, x, [contexts...]; strict=Val(true)) -> prep prepare_derivative(f!, y, backend, x, [contexts...]; strict=Val(true)) -> prep -$(docstring_prepare("derivative"; inplace=true)) +$(docstring_prepare("derivative"; inplace = true)) """ function prepare_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val = Val(true) + ) where {F, C} return prepare_derivative_nokwarg(strict, f, backend, x, contexts...) end function prepare_derivative( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) -) where {F,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val = Val(true) + ) where {F, C} return prepare_derivative_nokwarg(strict, f!, y, backend, x, contexts...) end @@ -25,20 +25,20 @@ end $(docstring_prepare!("derivative")) """ function prepare!_derivative( - f::F, old_prep::DerivativePrep, backend::AbstractADType, x, contexts::Vararg{Context,C}; -) where {F,C} + f::F, old_prep::DerivativePrep, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, old_prep, backend, x, contexts...) return prepare_derivative_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end function prepare!_derivative( - f!::F, - y, - old_prep::DerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + old_prep::DerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, old_prep, backend, x, contexts...) return prepare_derivative_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...) end @@ -52,15 +52,15 @@ Compute the value and the derivative of the function `f` at point `x`. $(docstring_preparation_hint("derivative")) """ function value_and_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f, backend, x, contexts...) return value_and_derivative(f, prep, backend, x, contexts...) end function value_and_derivative( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return value_and_derivative(f!, y, prep, backend, x, contexts...) end @@ -74,15 +74,15 @@ Compute the value and the derivative of the function `f` at point `x`, overwriti $(docstring_preparation_hint("derivative")) """ function value_and_derivative!( - f::F, der, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, der, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f, backend, x, contexts...) return value_and_derivative!(f, der, prep, backend, x, contexts...) end function value_and_derivative!( - f!::F, y, der, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, der, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return value_and_derivative!(f!, y, der, prep, backend, x, contexts...) end @@ -96,15 +96,15 @@ Compute the derivative of the function `f` at point `x`. $(docstring_preparation_hint("derivative")) """ function derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f, backend, x, contexts...) return derivative(f, prep, backend, x, contexts...) end function derivative( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return derivative(f!, y, prep, backend, x, contexts...) end @@ -118,29 +118,29 @@ Compute the derivative of the function `f` at point `x`, overwriting `der`. $(docstring_preparation_hint("derivative")) """ function derivative!( - f::F, der, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, der, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f, backend, x, contexts...) return derivative!(f, der, prep, backend, x, contexts...) end function derivative!( - f!::F, y, der, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, der, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return derivative!(f!, y, der, prep, backend, x, contexts...) end ## Preparation -struct PushforwardDerivativePrep{SIG,E<:PushforwardPrep} <: DerivativePrep{SIG} +struct PushforwardDerivativePrep{SIG, E <: PushforwardPrep} <: DerivativePrep{SIG} _sig::Val{SIG} pushforward_prep::E end function prepare_derivative_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, contexts...; strict) pushforward_prep = prepare_pushforward_nokwarg( strict, f, backend, x, (oneunit(x),), contexts... @@ -149,8 +149,8 @@ function prepare_derivative_nokwarg( end function prepare_derivative_nokwarg( - strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, contexts...; strict) pushforward_prep = prepare_pushforward_nokwarg( strict, f!, y, backend, x, (oneunit(x),), contexts... @@ -161,12 +161,12 @@ end ## One argument function value_and_derivative( - f::F, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) y, ty = value_and_pushforward( f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -175,13 +175,13 @@ function value_and_derivative( end function value_and_derivative!( - f::F, - der, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + der, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -190,25 +190,25 @@ function value_and_derivative!( end function derivative( - f::F, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) ty = pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) return only(ty) end function derivative!( - f::F, - der, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + der, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) pushforward!(f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) return der @@ -217,13 +217,13 @@ end ## Two arguments function value_and_derivative( - f!::F, - y, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) y, ty = value_and_pushforward( f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -232,14 +232,14 @@ function value_and_derivative( end function value_and_derivative!( - f!::F, - y, - der, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + der, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) y, _ = value_and_pushforward!( f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -248,27 +248,27 @@ function value_and_derivative!( end function derivative( - f!::F, - y, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) return only(ty) end function derivative!( - f!::F, - y, - der, - prep::PushforwardDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + der, + prep::PushforwardDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) pushforward!( f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... @@ -279,7 +279,7 @@ end ## Shuffled function shuffled_derivative( - x, f::F, backend::AbstractADType, rewrap::Rewrap{C}, unannotated_contexts::Vararg{Any,C} -) where {F,C} + x, f::F, backend::AbstractADType, rewrap::Rewrap{C}, unannotated_contexts::Vararg{Any, C} + ) where {F, C} return derivative(f, backend, x, rewrap(unannotated_contexts...)...) end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 2720307c7..9e41872c5 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -6,8 +6,8 @@ $(docstring_prepare("gradient")) """ function prepare_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val = Val(true) + ) where {F, C} return prepare_gradient_nokwarg(strict, f, backend, x, contexts...) end @@ -17,8 +17,8 @@ end $(docstring_prepare!("gradient")) """ function prepare!_gradient( - f::F, old_prep::GradientPrep, backend::AbstractADType, x, contexts::Vararg{Context,C}; -) where {F,C} + f::F, old_prep::GradientPrep, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, old_prep, backend, x, contexts...) return prepare_gradient_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end @@ -31,8 +31,8 @@ Compute the value and the gradient of the function `f` at point `x`. $(docstring_preparation_hint("gradient")) """ function value_and_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return value_and_gradient(f, prep, backend, x, contexts...) end @@ -45,8 +45,8 @@ Compute the value and the gradient of the function `f` at point `x`, overwriting $(docstring_preparation_hint("gradient")) """ function value_and_gradient!( - f::F, grad, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, grad, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return value_and_gradient!(f, grad, prep, backend, x, contexts...) end @@ -58,7 +58,7 @@ Compute the gradient of the function `f` at point `x`. $(docstring_preparation_hint("gradient")) """ -function gradient(f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}) where {F,C} +function gradient(f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}) where {F, C} prep = prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return gradient(f, prep, backend, x, contexts...) end @@ -71,23 +71,23 @@ Compute the gradient of the function `f` at point `x`, overwriting `grad`. $(docstring_preparation_hint("gradient")) """ function gradient!( - f::F, grad, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, grad, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return gradient!(f, grad, prep, backend, x, contexts...) end ## Preparation -struct PullbackGradientPrep{SIG,Y,E<:PullbackPrep} <: GradientPrep{SIG} +struct PullbackGradientPrep{SIG, Y, E <: PullbackPrep} <: GradientPrep{SIG} _sig::Val{SIG} y::Y pullback_prep::E end function prepare_gradient_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? pullback_prep = prepare_pullback_nokwarg( @@ -99,12 +99,12 @@ end ## One argument function value_and_gradient( - f::F, - prep::PullbackGradientPrep{SIG,Y}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,SIG,Y,C} + f::F, + prep::PullbackGradientPrep{SIG, Y}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, Y, C} check_prep(f, prep, backend, x, contexts...) y, tx = value_and_pullback( f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts... @@ -113,13 +113,13 @@ function value_and_gradient( end function value_and_gradient!( - f::F, - grad, - prep::PullbackGradientPrep{SIG,Y}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,SIG,Y,C} + f::F, + grad, + prep::PullbackGradientPrep{SIG, Y}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, Y, C} check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pullback!( f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts... @@ -128,25 +128,25 @@ function value_and_gradient!( end function gradient( - f::F, - prep::PullbackGradientPrep{SIG,Y}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,SIG,Y,C} + f::F, + prep::PullbackGradientPrep{SIG, Y}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, Y, C} check_prep(f, prep, backend, x, contexts...) tx = pullback(f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts...) return only(tx) end function gradient!( - f::F, - grad, - prep::PullbackGradientPrep{SIG,Y}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,SIG,Y,C} + f::F, + grad, + prep::PullbackGradientPrep{SIG, Y}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, Y, C} check_prep(f, prep, backend, x, contexts...) pullback!(f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts...) return grad @@ -155,43 +155,43 @@ end ## Shuffled function shuffled_gradient( - x, f::F, backend::AbstractADType, rewrap::Rewrap{C}, unannotated_contexts::Vararg{Any,C} -) where {F,C} + x, f::F, backend::AbstractADType, rewrap::Rewrap{C}, unannotated_contexts::Vararg{Any, C} + ) where {F, C} return gradient(f, backend, x, rewrap(unannotated_contexts...)...) end function shuffled_gradient( - x, - f::F, - prep::GradientPrep, - backend::AbstractADType, - rewrap::Rewrap{C}, - unannotated_contexts::Vararg{Any,C}, -) where {F,C} + x, + f::F, + prep::GradientPrep, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any, C}, + ) where {F, C} return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) end function shuffled_gradient!( - grad, - x, - f::F, - backend::AbstractADType, - rewrap::Rewrap{C}, - unannotated_contexts::Vararg{Any,C}, -) where {F,C} + grad, + x, + f::F, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any, C}, + ) where {F, C} gradient!(f, grad, backend, x, rewrap(unannotated_contexts...)...) return nothing end function shuffled_gradient!( - grad, - x, - f::F, - prep::GradientPrep, - backend::AbstractADType, - rewrap::Rewrap{C}, - unannotated_contexts::Vararg{Any,C}, -) where {F,C} + grad, + x, + f::F, + prep::GradientPrep, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any, C}, + ) where {F, C} gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...) return nothing end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 1b5ce5ccd..f651c2d05 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -4,17 +4,17 @@ prepare_jacobian(f, backend, x, [contexts...]; strict=Val(true)) -> prep prepare_jacobian(f!, y, backend, x, [contexts...]; strict=Val(true)) -> prep -$(docstring_prepare("jacobian"; inplace=true)) +$(docstring_prepare("jacobian"; inplace = true)) """ function prepare_jacobian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val = Val(true) + ) where {F, C} return prepare_jacobian_nokwarg(strict, f, backend, x, contexts...) end function prepare_jacobian( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) -) where {F,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val = Val(true) + ) where {F, C} return prepare_jacobian_nokwarg(strict, f!, y, backend, x, contexts...) end @@ -25,20 +25,20 @@ end $(docstring_prepare!("jacobian")) """ function prepare!_jacobian( - f::F, old_prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C}; -) where {F,C} + f::F, old_prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, old_prep, backend, x, contexts...) return prepare_jacobian_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end function prepare!_jacobian( - f!::F, - y, - old_prep::JacobianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; -) where {F,C} + f!::F, + y, + old_prep::JacobianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C} + ) where {F, C} check_prep(f!, y, old_prep, backend, x, contexts...) return prepare_jacobian_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...) end @@ -52,15 +52,15 @@ Compute the value and the Jacobian matrix of the function `f` at point `x`. $(docstring_preparation_hint("jacobian")) """ function value_and_jacobian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return value_and_jacobian(f, prep, backend, x, contexts...) end function value_and_jacobian( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return value_and_jacobian(f!, y, prep, backend, x, contexts...) end @@ -74,15 +74,15 @@ Compute the value and the Jacobian matrix of the function `f` at point `x`, over $(docstring_preparation_hint("jacobian")) """ function value_and_jacobian!( - f::F, jac, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, jac, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return value_and_jacobian!(f, jac, prep, backend, x, contexts...) end function value_and_jacobian!( - f!::F, y, jac, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, jac, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return value_and_jacobian!(f!, y, jac, prep, backend, x, contexts...) end @@ -95,14 +95,14 @@ Compute the Jacobian matrix of the function `f` at point `x`. $(docstring_preparation_hint("jacobian")) """ -function jacobian(f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}) where {F,C} +function jacobian(f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return jacobian(f, prep, backend, x, contexts...) end function jacobian( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return jacobian(f!, y, prep, backend, x, contexts...) end @@ -116,15 +116,15 @@ Compute the Jacobian matrix of the function `f` at point `x`, overwriting `jac`. $(docstring_preparation_hint("jacobian")) """ function jacobian!( - f::F, jac, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, jac, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return jacobian!(f, jac, prep, backend, x, contexts...) end function jacobian!( - f!::F, y, jac, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, jac, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return jacobian!(f!, y, jac, prep, backend, x, contexts...) end @@ -134,13 +134,13 @@ end abstract type StandardJacobianPrep{SIG} <: JacobianPrep{SIG} end struct PushforwardJacobianPrep{ - SIG, - BS<:BatchSizeSettings, - S<:AbstractVector{<:NTuple}, - R<:AbstractVector{<:NTuple}, - SE<:NTuple, - E<:PushforwardPrep, -} <: StandardJacobianPrep{SIG} + SIG, + BS <: BatchSizeSettings, + S <: AbstractVector{<:NTuple}, + R <: AbstractVector{<:NTuple}, + SE <: NTuple, + E <: PushforwardPrep, + } <: StandardJacobianPrep{SIG} _sig::Val{SIG} batch_size_settings::BS batched_seeds::S @@ -150,13 +150,13 @@ struct PushforwardJacobianPrep{ end struct PullbackJacobianPrep{ - SIG, - BS<:BatchSizeSettings, - S<:AbstractVector{<:NTuple}, - R<:AbstractVector{<:NTuple}, - SE<:NTuple, - E<:PullbackPrep, -} <: StandardJacobianPrep{SIG} + SIG, + BS <: BatchSizeSettings, + S <: AbstractVector{<:NTuple}, + R <: AbstractVector{<:NTuple}, + SE <: NTuple, + E <: PullbackPrep, + } <: StandardJacobianPrep{SIG} _sig::Val{SIG} batch_size_settings::BS batched_seeds::S @@ -166,8 +166,8 @@ struct PullbackJacobianPrep{ end function prepare_jacobian_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} y = f(x, map(unwrap, contexts)...) perf = pushforward_performance(backend) # type-unstable @@ -183,8 +183,8 @@ function prepare_jacobian_nokwarg( end function prepare_jacobian_nokwarg( - strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} perf = pushforward_performance(backend) # type-unstable if perf isa PushforwardFast @@ -199,15 +199,15 @@ function prepare_jacobian_nokwarg( end function _prepare_jacobian_aux( - strict::Val, - ::PushforwardFast, - batch_size_settings::BatchSizeSettings{B}, - y, - f_or_f!y::FY, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; -) where {B,FY,C} + strict::Val, + ::PushforwardFast, + batch_size_settings::BatchSizeSettings{B}, + y, + f_or_f!y::FY, + backend::AbstractADType, + x, + contexts::Vararg{Context, C} + ) where {B, FY, C} _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] @@ -230,15 +230,15 @@ function _prepare_jacobian_aux( end function _prepare_jacobian_aux( - strict::Val, - ::PushforwardSlow, - batch_size_settings::BatchSizeSettings{B}, - y, - f_or_f!y::FY, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; -) where {B,FY,C} + strict::Val, + ::PushforwardSlow, + batch_size_settings::BatchSizeSettings{B}, + y, + f_or_f!y::FY, + backend::AbstractADType, + x, + contexts::Vararg{Context, C} + ) where {B, FY, C} _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(y, ind) for ind in eachindex(y)] @@ -263,38 +263,38 @@ end ## One argument function jacobian( - f::F, - prep::StandardJacobianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::StandardJacobianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) return _jacobian_aux((f,), prep, backend, x, contexts...) end function jacobian!( - f::F, - jac, - prep::StandardJacobianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + jac, + prep::StandardJacobianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) return _jacobian_aux!((f,), jac, prep, backend, x, contexts...) end function value_and_jacobian( - f::F, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, prep, backend, x, contexts...) return f(x, map(unwrap, contexts)...), jacobian(f, prep, backend, x, contexts...) end function value_and_jacobian!( - f::F, jac, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, jac, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, prep, backend, x, contexts...) return f(x, map(unwrap, contexts)...), jacobian!(f, jac, prep, backend, x, contexts...) end @@ -302,33 +302,33 @@ end ## Two arguments function jacobian( - f!::F, - y, - prep::StandardJacobianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::StandardJacobianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) return _jacobian_aux((f!, y), prep, backend, x, contexts...) end function jacobian!( - f!::F, - y, - jac, - prep::StandardJacobianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + jac, + prep::StandardJacobianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) return _jacobian_aux!((f!, y), jac, prep, backend, x, contexts...) end function value_and_jacobian( - f!::F, y, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) jac = jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(unwrap, contexts)...) @@ -336,14 +336,14 @@ function value_and_jacobian( end function value_and_jacobian!( - f!::F, - y, - jac, - prep::JacobianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + jac, + prep::JacobianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, contexts...) jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(unwrap, contexts)...) @@ -353,12 +353,12 @@ end ## Common auxiliaries function _jacobian_aux( - f_or_f!y::FY, - prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B,true,aligned}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {FY,SIG,B,aligned,C} + f_or_f!y::FY, + prep::PushforwardJacobianPrep{SIG, <:BatchSizeSettings{B, true, aligned}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, B, aligned, C} (; batch_size_settings, batched_seeds, pushforward_prep) = prep (; B_last) = batch_size_settings dy_batch = pushforward( @@ -373,12 +373,12 @@ function _jacobian_aux( end function _jacobian_aux( - f_or_f!y::FY, - prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {FY,SIG,B,aligned,C} + f_or_f!y::FY, + prep::PushforwardJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, B, aligned, C} (; batch_size_settings, batched_seeds, seed_example, pushforward_prep) = prep (; A, B_last) = batch_size_settings @@ -406,12 +406,12 @@ function _jacobian_aux( end function _jacobian_aux( - f_or_f!y::FY, - prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B,true,aligned}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {FY,SIG,B,aligned,C} + f_or_f!y::FY, + prep::PullbackJacobianPrep{SIG, <:BatchSizeSettings{B, true, aligned}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, B, aligned, C} (; batch_size_settings, batched_seeds, pullback_prep) = prep (; B_last) = batch_size_settings dx_batch = pullback( @@ -429,12 +429,12 @@ function _jacobian_aux( end function _jacobian_aux( - f_or_f!y::FY, - prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {FY,SIG,B,aligned,C} + f_or_f!y::FY, + prep::PullbackJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, B, aligned, C} (; batch_size_settings, batched_seeds, seed_example, pullback_prep) = prep (; A, B_last) = batch_size_settings @@ -460,15 +460,15 @@ function _jacobian_aux( end function _jacobian_aux!( - f_or_f!y::FY, - jac, - prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {FY,SIG,B,C} + f_or_f!y::FY, + jac, + prep::PushforwardJacobianPrep{SIG, <:BatchSizeSettings{B}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, B, C} (; - batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep + batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep, ) = prep (; N) = batch_size_settings @@ -498,13 +498,13 @@ function _jacobian_aux!( end function _jacobian_aux!( - f_or_f!y::FY, - jac, - prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {FY,SIG,B,C} + f_or_f!y::FY, + jac, + prep::PullbackJacobianPrep{SIG, <:BatchSizeSettings{B}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, B, C} (; batch_size_settings, batched_seeds, batched_results, seed_example, pullback_prep) = prep (; N) = batch_size_settings diff --git a/DifferentiationInterface/src/first_order/mixed_mode.jl b/DifferentiationInterface/src/first_order/mixed_mode.jl index d15a3102f..7686fe2c7 100644 --- a/DifferentiationInterface/src/first_order/mixed_mode.jl +++ b/DifferentiationInterface/src/first_order/mixed_mode.jl @@ -11,13 +11,13 @@ Combination of a forward and a reverse mode backend for mixed-mode sparse Jacobi MixedMode(forward_backend, reverse_backend) """ -struct MixedMode{F<:AbstractADType,R<:AbstractADType} <: AbstractADType +struct MixedMode{F <: AbstractADType, R <: AbstractADType} <: AbstractADType forward::F reverse::R function MixedMode(forward::AbstractADType, reverse::AbstractADType) @assert pushforward_performance(forward) isa PushforwardFast @assert pullback_performance(reverse) isa PullbackFast - return new{typeof(forward),typeof(reverse)}(forward, reverse) + return new{typeof(forward), typeof(reverse)}(forward, reverse) end end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 7ba3a83df..183b15834 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -4,28 +4,28 @@ prepare_pullback(f, backend, x, ty, [contexts...]; strict=Val(true)) -> prep prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=Val(true)) -> prep -$(docstring_prepare("pullback"; inplace=true)) +$(docstring_prepare("pullback"; inplace = true)) """ function prepare_pullback( - f::F, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pullback_nokwarg(strict, f, backend, x, ty, contexts...) end function prepare_pullback( - f!::F, - y, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pullback_nokwarg(strict, f!, y, backend, x, ty, contexts...) end @@ -36,26 +36,26 @@ end $(docstring_prepare!("pullback")) """ function prepare!_pullback( - f::F, - old_prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + old_prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, old_prep, backend, x, ty, contexts...) return prepare_pullback_nokwarg(is_strict(old_prep), f, backend, x, ty, contexts...) end function prepare!_pullback( - f!::F, - y, - old_prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + old_prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, old_prep, backend, x, ty, contexts...) return prepare_pullback_nokwarg(is_strict(old_prep), f!, y, backend, x, ty, contexts...) end @@ -64,72 +64,72 @@ end prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=Val(true)) -> prep_same prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=Val(true)) -> prep_same -$(docstring_prepare("pullback"; samepoint=true, inplace=true)) +$(docstring_prepare("pullback"; samepoint = true, inplace = true)) """ function prepare_pullback_same_point( - f::F, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pullback_same_point_nokwarg(strict, f, backend, x, ty, contexts...) end function prepare_pullback_same_point( - f!::F, - y, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pullback_same_point_nokwarg(strict, f!, y, backend, x, ty, contexts...) end function prepare_pullback_same_point_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(strict, f, backend, x, ty, contexts...) return prepare_pullback_same_point(f, prep, backend, x, ty, contexts...) end function prepare_pullback_same_point_nokwarg( - strict::Val, - f!::F, - y, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, ty, contexts...) return prepare_pullback_same_point(f!, y, prep, backend, x, ty, contexts...) end function prepare_pullback_same_point( - f::F, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, ty, contexts...) return prep end function prepare_pullback_same_point( - f!::F, - y, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, ty, contexts...) return prep end @@ -140,7 +140,7 @@ end Compute the value and the pullback of the function `f` at point `x` with a tuple of tangents `ty`. -$(docstring_preparation_hint("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point = true)) !!! tip @@ -152,15 +152,15 @@ $(docstring_preparation_hint("pullback"; same_point=true)) Required primitive for reverse mode backends. """ function value_and_pullback( - f::F, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f, backend, x, ty, contexts...) return value_and_pullback(f, prep, backend, x, ty, contexts...) end function value_and_pullback( - f!::F, y, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f!, y, backend, x, ty, contexts...) return value_and_pullback(f!, y, prep, backend, x, ty, contexts...) end @@ -171,7 +171,7 @@ end Compute the value and the pullback of the function `f` at point `x` with a tuple of tangents `ty`, overwriting `dx`. -$(docstring_preparation_hint("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point = true)) !!! tip @@ -179,15 +179,15 @@ $(docstring_preparation_hint("pullback"; same_point=true)) This function could have been named `value_and_vjp!`. """ function value_and_pullback!( - f::F, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f::F, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f, backend, x, ty, contexts...) return value_and_pullback!(f, tx, prep, backend, x, ty, contexts...) end function value_and_pullback!( - f!::F, y, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f!, y, backend, x, ty, contexts...) return value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...) end @@ -198,7 +198,7 @@ end Compute the pullback of the function `f` at point `x` with a tuple of tangents `ty`. -$(docstring_preparation_hint("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point = true)) !!! tip @@ -206,15 +206,15 @@ $(docstring_preparation_hint("pullback"; same_point=true)) This function could have been named `vjp`. """ function pullback( - f::F, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f, backend, x, ty, contexts...) return pullback(f, prep, backend, x, ty, contexts...) end function pullback( - f!::F, y, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f!, y, backend, x, ty, contexts...) return pullback(f!, y, prep, backend, x, ty, contexts...) end @@ -225,7 +225,7 @@ end Compute the pullback of the function `f` at point `x` with a tuple of tangents `ty`, overwriting `dx`. -$(docstring_preparation_hint("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point = true)) !!! tip @@ -233,57 +233,57 @@ $(docstring_preparation_hint("pullback"; same_point=true)) This function could have been named `vjp!`. """ function pullback!( - f::F, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f::F, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f, backend, x, ty, contexts...) return pullback!(f, tx, prep, backend, x, ty, contexts...) end function pullback!( - f!::F, y, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, tx, backend::AbstractADType, x, ty, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pullback_nokwarg(Val(true), f!, y, backend, x, ty, contexts...) return pullback!(f!, y, tx, prep, backend, x, ty, contexts...) end ## Preparation -struct PushforwardPullbackPrep{SIG,E} <: PullbackPrep{SIG} +struct PushforwardPullbackPrep{SIG, E} <: PullbackPrep{SIG} _sig::Val{SIG} pushforward_prep::E end function prepare_pullback_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context, C} + ) where {F, C} return _prepare_pullback_aux( strict, pullback_performance(backend), f, backend, x, ty, contexts... ) end function prepare_pullback_nokwarg( - strict::Val, - f!::F, - y, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} return _prepare_pullback_aux( strict, pullback_performance(backend), f!, y, backend, x, ty, contexts... ) end function _prepare_pullback_aux( - strict::Val, - ::PullbackSlow, - f::F, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::PullbackSlow, + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, ty, contexts...; strict) dx = if x isa Number oneunit(x) @@ -297,15 +297,15 @@ function _prepare_pullback_aux( end function _prepare_pullback_aux( - strict::Val, - ::PullbackSlow, - f!::F, - y, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::PullbackSlow, + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) dx = if x isa Number oneunit(x) @@ -321,26 +321,26 @@ end ## One argument function _pullback_via_pushforward( - f::F, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::Real, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::Real, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) dx = dot(a, dy) return dx end function _pullback_via_pushforward( - f::F, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::Complex, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::Complex, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)) dx = real(dot(a, dy)) + im * real(dot(b, dy)) @@ -348,13 +348,13 @@ function _pullback_via_pushforward( end function _pullback_via_pushforward( - f::F, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::AbstractArray{<:Real}, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::AbstractArray{<:Real}, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} dx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) dot(a, dy) @@ -363,13 +363,13 @@ function _pullback_via_pushforward( end function _pullback_via_pushforward( - f::F, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::AbstractArray{<:Complex}, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::AbstractArray{<:Complex}, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} dx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) b = only( @@ -381,13 +381,13 @@ function _pullback_via_pushforward( end function value_and_pullback( - f::F, - prep::PushforwardPullbackPrep, - backend::AbstractADType, - x, - ty::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f::F, + prep::PushforwardPullbackPrep, + backend::AbstractADType, + x, + ty::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep y = f(x, map(unwrap, contexts)...) @@ -399,14 +399,14 @@ function value_and_pullback( end function value_and_pullback!( - f::F, - tx::NTuple, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + tx::NTuple, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, ty, contexts...) y, new_tx = value_and_pullback(f, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) @@ -414,26 +414,26 @@ function value_and_pullback!( end function pullback( - f::F, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback(f, prep, backend, x, ty, contexts...)[2] end function pullback!( - f::F, - tx::NTuple, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + tx::NTuple, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2] end @@ -441,28 +441,28 @@ end ## Two arguments function _pullback_via_pushforward( - f!::F, - y, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::Real, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::Real, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) dx = dot(a, dy) return dx end function _pullback_via_pushforward( - f!::F, - y, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::Complex, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::Complex, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) b = only( pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...) @@ -472,14 +472,14 @@ function _pullback_via_pushforward( end function _pullback_via_pushforward( - f!::F, - y, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::AbstractArray{<:Real}, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::AbstractArray{<:Real}, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} dx = map(CartesianIndices(x)) do j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) dot(a, dy) @@ -488,14 +488,14 @@ function _pullback_via_pushforward( end function _pullback_via_pushforward( - f!::F, - y, - pushforward_prep::PushforwardPrep, - backend::AbstractADType, - x::AbstractArray{<:Complex}, - dy, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + pushforward_prep::PushforwardPrep, + backend::AbstractADType, + x::AbstractArray{<:Complex}, + dy, + contexts::Vararg{Context, C}, + ) where {F, C} dx = map(CartesianIndices(x)) do j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) b = only( @@ -509,14 +509,14 @@ function _pullback_via_pushforward( end function value_and_pullback( - f!::F, - y, - prep::PushforwardPullbackPrep, - backend::AbstractADType, - x, - ty::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f!::F, + y, + prep::PushforwardPullbackPrep, + backend::AbstractADType, + x, + ty::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f!, y, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep tx = ntuple( @@ -530,15 +530,15 @@ function value_and_pullback( end function value_and_pullback!( - f!::F, - y, - tx::NTuple, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + tx::NTuple, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, ty, contexts...) y, new_tx = value_and_pullback(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) @@ -546,28 +546,28 @@ function value_and_pullback!( end function pullback( - f!::F, - y, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2] end function pullback!( - f!::F, - y, - tx::NTuple, - prep::PullbackPrep, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + tx::NTuple, + prep::PullbackPrep, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2] end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index c244e3289..d3922bc0a 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -4,28 +4,28 @@ prepare_pushforward(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=Val(true)) -> prep -$(docstring_prepare("pushforward"; inplace=true)) +$(docstring_prepare("pushforward"; inplace = true)) """ function prepare_pushforward( - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pushforward_nokwarg(strict, f, backend, x, tx, contexts...) end function prepare_pushforward( - f!::F, - y, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pushforward_nokwarg(strict, f!, y, backend, x, tx, contexts...) end @@ -36,26 +36,26 @@ end $(docstring_prepare!("pushforward")) """ function prepare!_pushforward( - f::F, - old_prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + old_prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, old_prep, backend, x, tx, contexts...) return prepare_pushforward_nokwarg(is_strict(old_prep), f, backend, x, tx, contexts...) end function prepare!_pushforward( - f!::F, - y, - old_prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + old_prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, old_prep, backend, x, tx, contexts...) return prepare_pushforward_nokwarg( is_strict(old_prep), f!, y, backend, x, tx, contexts... @@ -66,74 +66,74 @@ end prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same -$(docstring_prepare("pushforward"; samepoint=true, inplace=true)) +$(docstring_prepare("pushforward"; samepoint = true, inplace = true)) """ function prepare_pushforward_same_point( - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pushforward_same_point_nokwarg(strict, f, backend, x, tx, contexts...) end function prepare_pushforward_same_point( - f!::F, - y, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_pushforward_same_point_nokwarg( strict, f!, y, backend, x, tx, contexts... ) end function prepare_pushforward_same_point_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(strict, f, backend, x, tx, contexts...) return prepare_pushforward_same_point(f, prep, backend, x, tx, contexts...) end function prepare_pushforward_same_point_nokwarg( - strict::Val, - f!::F, - y, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(strict, f!, y, backend, x, tx, contexts...) return prepare_pushforward_same_point(f!, y, prep, backend, x, tx, contexts...) end function prepare_pushforward_same_point( - f::F, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return prep end function prepare_pushforward_same_point( - f!::F, - y, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, tx, contexts...) return prep end @@ -144,7 +144,7 @@ end Compute the value and the pushforward of the function `f` at point `x` with a tuple of tangents `tx`. -$(docstring_preparation_hint("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point = true)) !!! tip @@ -156,15 +156,15 @@ $(docstring_preparation_hint("pushforward"; same_point=true)) Required primitive for forward mode backends. """ function value_and_pushforward( - f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f, backend, x, tx, contexts...) return value_and_pushforward(f, prep, backend, x, tx, contexts...) end function value_and_pushforward( - f!::F, y, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f!, y, backend, x, tx, contexts...) return value_and_pushforward(f!, y, prep, backend, x, tx, contexts...) end @@ -175,7 +175,7 @@ end Compute the value and the pushforward of the function `f` at point `x` with a tuple of tangents `tx`, overwriting `ty`. -$(docstring_preparation_hint("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point = true)) !!! tip @@ -183,15 +183,15 @@ $(docstring_preparation_hint("pushforward"; same_point=true)) This function could have been named `value_and_jvp!`. """ function value_and_pushforward!( - f::F, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f::F, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f, backend, x, tx, contexts...) return value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...) end function value_and_pushforward!( - f!::F, y, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f!, y, backend, x, tx, contexts...) return value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) end @@ -202,7 +202,7 @@ end Compute the pushforward of the function `f` at point `x` with a tuple of tangents `tx`. -$(docstring_preparation_hint("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point = true)) !!! tip @@ -210,15 +210,15 @@ $(docstring_preparation_hint("pushforward"; same_point=true)) This function could have been named `jvp`. """ function pushforward( - f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f, backend, x, tx, contexts...) return pushforward(f, prep, backend, x, tx, contexts...) end function pushforward( - f!::F, y, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f!, y, backend, x, tx, contexts...) return pushforward(f!, y, prep, backend, x, tx, contexts...) end @@ -229,7 +229,7 @@ end Compute the pushforward of the function `f` at point `x` with a tuple of tangents `tx`, overwriting `ty`. -$(docstring_preparation_hint("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point = true)) !!! tip @@ -237,57 +237,57 @@ $(docstring_preparation_hint("pushforward"; same_point=true)) This function could have been named `jvp!`. """ function pushforward!( - f::F, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f::F, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f, backend, x, tx, contexts...) return pushforward!(f, ty, prep, backend, x, tx, contexts...) end function pushforward!( - f!::F, y, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, ty, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_pushforward_nokwarg(Val(true), f!, y, backend, x, tx, contexts...) return pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) end ## Preparation -struct PullbackPushforwardPrep{SIG,E} <: PushforwardPrep{SIG} +struct PullbackPushforwardPrep{SIG, E} <: PushforwardPrep{SIG} _sig::Val{SIG} pullback_prep::E end function prepare_pushforward_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context, C} + ) where {F, C} return _prepare_pushforward_aux( strict, pushforward_performance(backend), f, backend, x, tx, contexts... ) end function prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} return _prepare_pushforward_aux( strict, pushforward_performance(backend), f!, y, backend, x, tx, contexts... ) end function _prepare_pushforward_aux( - strict::Val, - ::PushforwardSlow, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::PushforwardSlow, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) dy = if y isa Number @@ -300,15 +300,15 @@ function _prepare_pushforward_aux( end function _prepare_pushforward_aux( - strict::Val, - ::PushforwardSlow, - f!::F, - y, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::PushforwardSlow, + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) dy = basis(y) pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...) @@ -318,28 +318,28 @@ end ## One argument function _pushforward_via_pullback( - y::Number, - f::F, - pullback_prep::PullbackPrep, - backend::AbstractADType, - x, - dx, - contexts::Vararg{Context,C}, -) where {F,C} + y::Number, + f::F, + pullback_prep::PullbackPrep, + backend::AbstractADType, + x, + dx, + contexts::Vararg{Context, C}, + ) where {F, C} a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...)) dy = dot(a, dx) return dy end function _pushforward_via_pullback( - y::Complex, - f::F, - pullback_prep::PullbackPrep, - backend::AbstractADType, - x, - dx, - contexts::Vararg{Context,C}, -) where {F,C} + y::Complex, + f::F, + pullback_prep::PullbackPrep, + backend::AbstractADType, + x, + dx, + contexts::Vararg{Context, C}, + ) where {F, C} a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...)) b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y),), contexts...)) dy = real(dot(a, dx)) + im * real(dot(b, dx)) @@ -347,14 +347,14 @@ function _pushforward_via_pullback( end function _pushforward_via_pullback( - y::AbstractArray{<:Real}, - f::F, - pullback_prep::PullbackPrep, - backend::AbstractADType, - x, - dx, - contexts::Vararg{Context,C}, -) where {F,C} + y::AbstractArray{<:Real}, + f::F, + pullback_prep::PullbackPrep, + backend::AbstractADType, + x, + dx, + contexts::Vararg{Context, C}, + ) where {F, C} dy = map(CartesianIndices(y)) do i a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)) dot(a, dx) @@ -363,14 +363,14 @@ function _pushforward_via_pullback( end function _pushforward_via_pullback( - y::AbstractArray{<:Complex}, - f::F, - pullback_prep::PullbackPrep, - backend::AbstractADType, - x, - dx, - contexts::Vararg{Context,C}, -) where {F,C} + y::AbstractArray{<:Complex}, + f::F, + pullback_prep::PullbackPrep, + backend::AbstractADType, + x, + dx, + contexts::Vararg{Context, C}, + ) where {F, C} dy = map(CartesianIndices(y)) do i a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)) b = only(pullback(f, pullback_prep, backend, x, (im * basis(y, i),), contexts...)) @@ -380,13 +380,13 @@ function _pushforward_via_pullback( end function value_and_pushforward( - f::F, - prep::PullbackPushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f::F, + prep::PullbackPushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f, prep, backend, x, tx, contexts...) (; pullback_prep) = prep y = f(x, map(unwrap, contexts)...) @@ -398,14 +398,14 @@ function value_and_pushforward( end function value_and_pushforward!( - f::F, - ty::NTuple, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + ty::NTuple, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) y, new_ty = value_and_pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) @@ -413,26 +413,26 @@ function value_and_pushforward!( end function pushforward( - f::F, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] end function pushforward!( - f::F, - ty::NTuple, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + ty::NTuple, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)[2] end @@ -440,14 +440,14 @@ end ## Two arguments function _pushforward_via_pullback( - f!::F, - y::AbstractArray{<:Real}, - pullback_prep::PullbackPrep, - backend::AbstractADType, - x, - dx, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y::AbstractArray{<:Real}, + pullback_prep::PullbackPrep, + backend::AbstractADType, + x, + dx, + contexts::Vararg{Context, C}, + ) where {F, C} dy = map(CartesianIndices(y)) do i # preserve shape a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) dot(a, dx) @@ -456,14 +456,14 @@ function _pushforward_via_pullback( end function _pushforward_via_pullback( - f!::F, - y::AbstractArray{<:Complex}, - pullback_prep::PullbackPrep, - backend::AbstractADType, - x, - dx, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y::AbstractArray{<:Complex}, + pullback_prep::PullbackPrep, + backend::AbstractADType, + x, + dx, + contexts::Vararg{Context, C}, + ) where {F, C} dy = map(CartesianIndices(y)) do i # preserve shape a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) b = only( @@ -475,19 +475,19 @@ function _pushforward_via_pullback( end function value_and_pushforward( - f!::F, - y, - prep::PullbackPushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f!::F, + y, + prep::PullbackPushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f!, y, prep, backend, x, tx, contexts...) (; pullback_prep) = prep ty = ntuple( b -> - _pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...), + _pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...), Val(B), ) f!(y, x, map(unwrap, contexts)...) @@ -495,15 +495,15 @@ function value_and_pushforward( end function value_and_pushforward!( - f!::F, - y, - ty::NTuple, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + ty::NTuple, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, tx, contexts...) y, new_ty = value_and_pushforward(f!, y, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) @@ -511,28 +511,28 @@ function value_and_pushforward!( end function pushforward( - f!::F, - y, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2] end function pushforward!( - f!::F, - y, - ty::NTuple, - prep::PushforwardPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + ty::NTuple, + prep::PushforwardPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)[2] end @@ -540,13 +540,13 @@ end ## Shuffled function shuffled_single_pushforward( - x, - f::F, - backend::AbstractADType, - dx, - rewrap::Rewrap{C}, - unannotated_contexts::Vararg{Any,C}, -) where {F,C} + x, + f::F, + backend::AbstractADType, + dx, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any, C}, + ) where {F, C} ty = pushforward(f, backend, x, (dx,), rewrap(unannotated_contexts...)...) return only(ty) end diff --git a/DifferentiationInterface/src/init.jl b/DifferentiationInterface/src/init.jl index 71732eb19..78c910d2e 100644 --- a/DifferentiationInterface/src/init.jl +++ b/DifferentiationInterface/src/init.jl @@ -1,5 +1,5 @@ function __init__() - Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs + return Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs if exc.f in (_prepare_pushforward_aux, _prepare_pullback_aux) B = first(T for T in argtypes if T <: AbstractADType) packages = required_packages(B) @@ -13,14 +13,14 @@ function __init__() printstyled( io, "\n\nThe autodiff backend you chose requires a package which may not be loaded. Please run the following command and try again:"; - bold=true, + bold = true, ) - printstyled(io, "\n\n\t$import_statement"; color=:cyan, bold=true) + printstyled(io, "\n\n\t$import_statement"; color = :cyan, bold = true) else printstyled( io, "\n\nThe autodiff backend you chose may not be compatible with the operation you want to perform. Please refer to the documentation of DifferentiationInterface.jl and open an issue if necessary."; - bold=true, + bold = true, ) end end diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index eeac288d9..f0c2ecf38 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -69,7 +69,7 @@ julia> Zygote.gradient(alg, [3.0, 5.0])[1] 70.0 ``` """ -struct DifferentiateWith{F,B<:AbstractADType} +struct DifferentiateWith{F, B <: AbstractADType} f::F backend::B end @@ -82,9 +82,9 @@ function Base.show(io::IO, dw::DifferentiateWith) io, DifferentiateWith, "(", - repr(f; context=io), + repr(f; context = io), ", ", - repr(backend; context=io), + repr(backend; context = io), ")", ) end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index db8e7b276..5bbd3490a 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -21,8 +21,8 @@ function overloaded_input(::typeof(pushforward), f, backend::FromPrimitive, x, t end function overloaded_input( - ::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple -) + ::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple + ) return overloaded_input(pushforward, f!, y, backend.backend, x, tx) end @@ -36,39 +36,39 @@ Wrapper which forces a given backend to act as a forward-mode backend, using onl This can be useful to circumvent high-level operators when they have impractical limitations. For instance, ForwardDiff.jl's `jacobian` does not support GPU arrays but its `pushforward` does, so `AutoForwardFromPrimitive(AutoForwardDiff())` has a GPU-friendly `jacobian`. """ -struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace} +struct AutoForwardFromPrimitive{inplace, B <: AbstractADType} <: FromPrimitive{inplace} backend::B end function AutoForwardFromPrimitive( - backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend)) -) - return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend) + backend::AbstractADType; inplace::Bool = Bool(inplace_support(backend)) + ) + return AutoForwardFromPrimitive{inplace, typeof(backend)}(backend) end ADTypes.mode(::AutoForwardFromPrimitive) = ADTypes.ForwardMode() function threshold_batchsize( - backend::AutoForwardFromPrimitive{inplace}, dimension::Integer -) where {inplace} + backend::AutoForwardFromPrimitive{inplace}, dimension::Integer + ) where {inplace} return AutoForwardFromPrimitive( threshold_batchsize(backend.backend, dimension); inplace ) end -struct FromPrimitivePushforwardPrep{SIG,E<:PushforwardPrep} <: PushforwardPrep{SIG} +struct FromPrimitivePushforwardPrep{SIG, E <: PushforwardPrep} <: PushforwardPrep{SIG} _sig::Val{SIG} pushforward_prep::E end function prepare_pushforward_nokwarg( - strict::Val, - f::F, - backend::AutoForwardFromPrimitive, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) primitive_prep = prepare_pushforward_nokwarg( strict, f, backend.backend, x, tx, contexts... @@ -77,14 +77,14 @@ function prepare_pushforward_nokwarg( end function prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoForwardFromPrimitive, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) primitive_prep = prepare_pushforward_nokwarg( strict, f!, y, backend.backend, x, tx, contexts... @@ -93,13 +93,13 @@ function prepare_pushforward_nokwarg( end function value_and_pushforward( - f::F, - prep::FromPrimitivePushforwardPrep, - backend::AutoForwardFromPrimitive, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::FromPrimitivePushforwardPrep, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward( f, prep.pushforward_prep, backend.backend, x, tx, contexts... @@ -107,14 +107,14 @@ function value_and_pushforward( end function value_and_pushforward( - f!::F, - y, - prep::FromPrimitivePushforwardPrep, - backend::AutoForwardFromPrimitive, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::FromPrimitivePushforwardPrep, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward( f!, y, prep.pushforward_prep, backend.backend, x, tx, contexts... @@ -122,14 +122,14 @@ function value_and_pushforward( end function value_and_pushforward!( - f::F, - ty::NTuple, - prep::FromPrimitivePushforwardPrep, - backend::AutoForwardFromPrimitive, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + ty::NTuple, + prep::FromPrimitivePushforwardPrep, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward!( f, ty, prep.pushforward_prep, backend.backend, x, tx, contexts... @@ -137,15 +137,15 @@ function value_and_pushforward!( end function value_and_pushforward!( - f!::F, - y, - ty::NTuple, - prep::FromPrimitivePushforwardPrep, - backend::AutoForwardFromPrimitive, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + ty::NTuple, + prep::FromPrimitivePushforwardPrep, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward!( f!, y, ty, prep.pushforward_prep, backend.backend, x, tx, contexts... @@ -157,39 +157,39 @@ end Wrapper which forces a given backend to act as a reverse-mode backend, using only its native `value_and_pullback` implementation and rebuilding the rest from scratch. """ -struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace} +struct AutoReverseFromPrimitive{inplace, B <: AbstractADType} <: FromPrimitive{inplace} backend::B end function AutoReverseFromPrimitive( - backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend)) -) - return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend) + backend::AbstractADType; inplace::Bool = Bool(inplace_support(backend)) + ) + return AutoReverseFromPrimitive{inplace, typeof(backend)}(backend) end ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode() function threshold_batchsize( - backend::AutoReverseFromPrimitive{inplace}, dimension::Integer -) where {inplace} + backend::AutoReverseFromPrimitive{inplace}, dimension::Integer + ) where {inplace} return AutoReverseFromPrimitive( threshold_batchsize(backend.backend, dimension); inplace ) end -struct FromPrimitivePullbackPrep{SIG,E<:PullbackPrep} <: PullbackPrep{SIG} +struct FromPrimitivePullbackPrep{SIG, E <: PullbackPrep} <: PullbackPrep{SIG} _sig::Val{SIG} pullback_prep::E end function prepare_pullback_nokwarg( - strict::Val, - f::F, - backend::AutoReverseFromPrimitive, - x, - ty::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, ty, contexts...; strict) primitive_prep = prepare_pullback_nokwarg( strict, f, backend.backend, x, ty, contexts... @@ -198,14 +198,14 @@ function prepare_pullback_nokwarg( end function prepare_pullback_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoReverseFromPrimitive, - x, - ty::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) primitive_prep = prepare_pullback_nokwarg( strict, f!, y, backend.backend, x, ty, contexts... @@ -214,26 +214,26 @@ function prepare_pullback_nokwarg( end function value_and_pullback( - f::F, - prep::FromPrimitivePullbackPrep, - backend::AutoReverseFromPrimitive, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::FromPrimitivePullbackPrep, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback(f, prep.pullback_prep, backend.backend, x, ty, contexts...) end function value_and_pullback( - f!::F, - y, - prep::FromPrimitivePullbackPrep, - backend::AutoReverseFromPrimitive, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + prep::FromPrimitivePullbackPrep, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback( f!, y, prep.pullback_prep, backend.backend, x, ty, contexts... @@ -241,14 +241,14 @@ function value_and_pullback( end function value_and_pullback!( - f::F, - tx::NTuple, - prep::FromPrimitivePullbackPrep, - backend::AutoReverseFromPrimitive, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + tx::NTuple, + prep::FromPrimitivePullbackPrep, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback!( f, tx, prep.pullback_prep, backend.backend, x, ty, contexts... @@ -256,15 +256,15 @@ function value_and_pullback!( end function value_and_pullback!( - f!::F, - y, - tx::NTuple, - prep::FromPrimitivePullbackPrep, - backend::AutoReverseFromPrimitive, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + tx::NTuple, + prep::FromPrimitivePullbackPrep, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback!( f!, y, tx, prep.pullback_prep, backend.backend, x, ty, contexts... diff --git a/DifferentiationInterface/src/misc/overloading.jl b/DifferentiationInterface/src/misc/overloading.jl index 20a38aa5f..605791271 100644 --- a/DifferentiationInterface/src/misc/overloading.jl +++ b/DifferentiationInterface/src/misc/overloading.jl @@ -14,7 +14,7 @@ function overloaded_input(::typeof(pushforward), f, backend::AbstractADType, x, end function overloaded_input( - ::typeof(pushforward), f!, y, backend::AbstractADType, x, tx::NTuple -) + ::typeof(pushforward), f!, y, backend::AbstractADType, x, tx::NTuple + ) throw(ArgumentError("Overloaded input not defined")) end diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index 8481cad97..78a8c7697 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -7,12 +7,12 @@ Forward mode backend based on the finite difference `(f(x + ε) - f(x)) / ε`, w AutoSimpleFiniteDiff(ε=1e-5; chunksize=nothing) """ -struct AutoSimpleFiniteDiff{chunksize,T<:Real} <: AbstractADType +struct AutoSimpleFiniteDiff{chunksize, T <: Real} <: AbstractADType ε::T end -function AutoSimpleFiniteDiff(ε=1e-5; chunksize=nothing) - return AutoSimpleFiniteDiff{chunksize,typeof(ε)}(ε) +function AutoSimpleFiniteDiff(ε = 1.0e-5; chunksize = nothing) + return AutoSimpleFiniteDiff{chunksize, typeof(ε)}(ε) end ADTypes.mode(::AutoSimpleFiniteDiff) = ForwardMode() @@ -30,45 +30,45 @@ function pick_batchsize(::AutoSimpleFiniteDiff{chunksize}, N::Integer) where {ch end function threshold_batchsize( - backend::AutoSimpleFiniteDiff{chunksize1}, chunksize2::Integer -) where {chunksize1} + backend::AutoSimpleFiniteDiff{chunksize1}, chunksize2::Integer + ) where {chunksize1} chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2) return AutoSimpleFiniteDiff(backend.ε; chunksize) end function prepare_pushforward_nokwarg( - strict::Val, - f::F, - backend::AutoSimpleFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f::F, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) end function prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoSimpleFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) end function value_and_pushforward( - f::F, - prep::NoPushforwardPrep, - backend::AutoSimpleFiniteDiff, - x, - tx::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f::F, + prep::NoPushforwardPrep, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f, prep, backend, x, tx, contexts...) ε = eltype(x)(backend.ε) y = f(x, map(unwrap, contexts)...) @@ -81,14 +81,14 @@ function value_and_pushforward( end function value_and_pushforward( - f!::F, - y, - prep::NoPushforwardPrep, - backend::AutoSimpleFiniteDiff, - x, - tx::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f!::F, + y, + prep::NoPushforwardPrep, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f!, y, prep, backend, x, tx, contexts...) ε = eltype(x)(backend.ε) ty = map(tx) do dx diff --git a/DifferentiationInterface/src/misc/sparsity_detector.jl b/DifferentiationInterface/src/misc/sparsity_detector.jl index 8d3823e1d..0ddf729c6 100644 --- a/DifferentiationInterface/src/misc/sparsity_detector.jl +++ b/DifferentiationInterface/src/misc/sparsity_detector.jl @@ -69,7 +69,7 @@ ADTypes.jacobian_sparsity(x -> [prod(x)], [0, 1], detector) 1 ⋅ ``` """ -struct DenseSparsityDetector{method,B} <: ADTypes.AbstractSparsityDetector +struct DenseSparsityDetector{method, B} <: ADTypes.AbstractSparsityDetector backend::B atol::Float64 end @@ -80,20 +80,20 @@ function Base.show(io::IO, detector::DenseSparsityDetector{method}) where {metho io, DenseSparsityDetector, "(", - repr(backend; context=io), + repr(backend; context = io), "; atol=$atol, method=", - repr(method; context=io), + repr(method; context = io), ")", ) end function DenseSparsityDetector( - backend::AbstractADType; atol::Float64, method::Symbol=:iterative -) + backend::AbstractADType; atol::Float64, method::Symbol = :iterative + ) if !(method in (:iterative, :direct)) throw( ArgumentError("The keyword `method` must be either `:iterative` or `:direct`.") ) end - return DenseSparsityDetector{method,typeof(backend)}(backend, atol) + return DenseSparsityDetector{method, typeof(backend)}(backend, atol) end diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index e74449f3c..c91d1b204 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -21,33 +21,33 @@ check_available(::AutoZeroForward) = true inplace_support(::AutoZeroForward) = InPlaceSupported() function prepare_pushforward_nokwarg( - strict::Val, f::F, backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) end function prepare_pushforward_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoZeroForward, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoZeroForward, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) end function value_and_pushforward( - f::F, - prep::NoPushforwardPrep, - backend::AutoZeroForward, - x, - tx::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f::F, + prep::NoPushforwardPrep, + backend::AutoZeroForward, + x, + tx::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f, prep, backend, x, tx, contexts...) y = f(x, map(unwrap, contexts)...) ty = map(ReturnZero(y), tx) @@ -55,14 +55,14 @@ function value_and_pushforward( end function value_and_pushforward( - f!::F, - y, - prep::NoPushforwardPrep, - backend::AutoZeroForward, - x, - tx::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f!::F, + y, + prep::NoPushforwardPrep, + backend::AutoZeroForward, + x, + tx::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(unwrap, contexts)...) ty = map(ReturnZero(y), tx) @@ -70,14 +70,14 @@ function value_and_pushforward( end function value_and_pushforward!( - f::F, - ty::NTuple, - prep::NoPushforwardPrep, - backend::AutoZeroForward, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + ty::NTuple, + prep::NoPushforwardPrep, + backend::AutoZeroForward, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) y = f(x, map(unwrap, contexts)...) for b in eachindex(ty) @@ -87,15 +87,15 @@ function value_and_pushforward!( end function value_and_pushforward!( - f!::F, - y, - ty::NTuple, - prep::NoPushforwardPrep, - backend::AutoZeroForward, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + ty::NTuple, + prep::NoPushforwardPrep, + backend::AutoZeroForward, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(unwrap, contexts)...) for b in eachindex(ty) @@ -119,33 +119,33 @@ check_available(::AutoZeroReverse) = true inplace_support(::AutoZeroReverse) = InPlaceSupported() function prepare_pullback_nokwarg( - strict::Val, f::F, backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, ty, contexts...; strict) return NoPullbackPrep(_sig) end function prepare_pullback_nokwarg( - strict::Val, - f!::F, - y, - backend::AutoZeroReverse, - x, - ty::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + f!::F, + y, + backend::AutoZeroReverse, + x, + ty::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) return NoPullbackPrep(_sig) end function value_and_pullback( - f::F, - prep::NoPullbackPrep, - backend::AutoZeroReverse, - x, - ty::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f::F, + prep::NoPullbackPrep, + backend::AutoZeroReverse, + x, + ty::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f, prep, backend, x, ty, contexts...) y = f(x, map(unwrap, contexts)...) tx = ntuple(ReturnZero(x), Val(B)) @@ -153,14 +153,14 @@ function value_and_pullback( end function value_and_pullback( - f!::F, - y, - prep::NoPullbackPrep, - backend::AutoZeroReverse, - x, - ty::NTuple{B}, - contexts::Vararg{Context,C}, -) where {F,B,C} + f!::F, + y, + prep::NoPullbackPrep, + backend::AutoZeroReverse, + x, + ty::NTuple{B}, + contexts::Vararg{Context, C}, + ) where {F, B, C} check_prep(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(unwrap, contexts)...) tx = ntuple(ReturnZero(x), Val(B)) @@ -168,14 +168,14 @@ function value_and_pullback( end function value_and_pullback!( - f::F, - tx::NTuple, - prep::NoPullbackPrep, - backend::AutoZeroReverse, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + tx::NTuple, + prep::NoPullbackPrep, + backend::AutoZeroReverse, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, ty, contexts...) y = f(x, map(unwrap, contexts)...) for b in eachindex(tx) @@ -185,15 +185,15 @@ function value_and_pullback!( end function value_and_pullback!( - f!::F, - y, - tx::NTuple, - prep::NoPullbackPrep, - backend::AutoZeroReverse, - x, - ty::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f!::F, + y, + tx::NTuple, + prep::NoPullbackPrep, + backend::AutoZeroReverse, + x, + ty::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(unwrap, contexts)...) for b in eachindex(tx) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index a815fb08c..f8d74b4a4 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -6,8 +6,8 @@ $(docstring_prepare("hessian")) """ function prepare_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val = Val(true) + ) where {F, C} return prepare_hessian_nokwarg(strict, f, backend, x, contexts...) end @@ -17,8 +17,8 @@ end $(docstring_prepare!("hessian")) """ function prepare!_hessian( - f::F, old_prep::HessianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C}; -) where {F,C} + f::F, old_prep::HessianPrep, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, old_prep, backend, x, contexts...) return prepare_hessian_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end @@ -30,7 +30,7 @@ Compute the Hessian matrix of the function `f` at point `x`. $(docstring_preparation_hint("hessian")) """ -function hessian(f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}) where {F,C} +function hessian(f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}) where {F, C} prep = prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return hessian(f, prep, backend, x, contexts...) end @@ -43,8 +43,8 @@ Compute the Hessian matrix of the function `f` at point `x`, overwriting `hess`. $(docstring_preparation_hint("hessian")) """ function hessian!( - f::F, hess, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, hess, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return hessian!(f, hess, prep, backend, x, contexts...) end @@ -57,8 +57,8 @@ Compute the value, gradient vector and Hessian matrix of the function `f` at poi $(docstring_preparation_hint("hessian")) """ function value_gradient_and_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return value_gradient_and_hessian(f, prep, backend, x, contexts...) end @@ -71,8 +71,8 @@ Compute the value, gradient vector and Hessian matrix of the function `f` at poi $(docstring_preparation_hint("hessian")) """ function value_gradient_and_hessian!( - f::F, grad, hess, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, grad, hess, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return value_gradient_and_hessian!(f, grad, hess, prep, backend, x, contexts...) end @@ -80,14 +80,14 @@ end ## Preparation struct HVPGradientHessianPrep{ - SIG, - BS<:BatchSizeSettings, - S<:AbstractVector{<:NTuple}, - R<:AbstractVector{<:NTuple}, - SE<:NTuple, - E2<:HVPPrep, - E1<:GradientPrep, -} <: HessianPrep{SIG} + SIG, + BS <: BatchSizeSettings, + S <: AbstractVector{<:NTuple}, + R <: AbstractVector{<:NTuple}, + SE <: NTuple, + E2 <: HVPPrep, + E1 <: GradientPrep, + } <: HessianPrep{SIG} _sig::Val{SIG} batch_size_settings::BS batched_seeds::S @@ -98,8 +98,8 @@ struct HVPGradientHessianPrep{ end function prepare_hessian_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} # type-unstable batch_size_settings = pick_batchsize(outer(backend), x) # function barrier @@ -107,13 +107,13 @@ function prepare_hessian_nokwarg( end function _prepare_hessian_aux( - strict::Val, - batch_size_settings::BatchSizeSettings{B}, - f::F, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; -) where {B,F,C} + strict::Val, + batch_size_settings::BatchSizeSettings{B}, + f::F, + backend::AbstractADType, + x, + contexts::Vararg{Context, C} + ) where {B, F, C} _sig = signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] @@ -138,12 +138,12 @@ end ## One argument function hessian( - f::F, - prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B,true}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,SIG,B,C} + f::F, + prep::HVPGradientHessianPrep{SIG, <:BatchSizeSettings{B, true}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, B, C} check_prep(f, prep, backend, x, contexts...) (; batched_seeds, hvp_prep) = prep dg_batch = hvp(f, hvp_prep, backend, x, only(batched_seeds), contexts...) @@ -152,12 +152,12 @@ function hessian( end function hessian( - f::F, - prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,SIG,B,aligned,C} + f::F, + prep::HVPGradientHessianPrep{SIG, <:BatchSizeSettings{B, false, aligned}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, B, aligned, C} check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, batched_seeds, seed_example, hvp_prep) = prep (; A, B_last) = batch_size_settings @@ -179,13 +179,13 @@ function hessian( end function hessian!( - f::F, - hess, - prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B}}, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,SIG,B,C} + f::F, + hess, + prep::HVPGradientHessianPrep{SIG, <:BatchSizeSettings{B}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, B, C} check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, batched_seeds, batched_results, seed_example, hvp_prep) = prep (; N) = batch_size_settings @@ -210,12 +210,12 @@ function hessian!( end function value_gradient_and_hessian( - f::F, - prep::HVPGradientHessianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::HVPGradientHessianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) y, grad = value_and_gradient(f, prep.gradient_prep, inner(backend), x, contexts...) hess = hessian(f, prep, backend, x, contexts...) @@ -223,14 +223,14 @@ function value_gradient_and_hessian( end function value_gradient_and_hessian!( - f::F, - grad, - hess, - prep::HVPGradientHessianPrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + grad, + hess, + prep::HVPGradientHessianPrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) y, _ = value_and_gradient!(f, grad, prep.gradient_prep, inner(backend), x, contexts...) hessian!(f, hess, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 8a85bf1c1..58afeadbf 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -6,13 +6,13 @@ $(docstring_prepare("hvp")) """ function prepare_hvp( - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_hvp_nokwarg(strict, f, backend, x, tx, contexts...) end @@ -22,13 +22,13 @@ end $(docstring_prepare("hvp")) """ function prepare!_hvp( - f::F, - old_prep::HVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + old_prep::HVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, old_prep, backend, x, tx, contexts...) return prepare_hvp_nokwarg(is_strict(old_prep), f, backend, x, tx, contexts...) end @@ -36,29 +36,29 @@ end """ prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same -$(docstring_prepare("hvp"; samepoint=true)) +$(docstring_prepare("hvp"; samepoint = true)) """ function prepare_hvp_same_point( - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(true), -) where {F,C} + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}; + strict::Val = Val(true), + ) where {F, C} return prepare_hvp_same_point_nokwarg(strict, f, backend, x, tx, contexts...) end function prepare_hvp_same_point_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_hvp_nokwarg(strict, f, backend, x, tx, contexts...) return prepare_hvp_same_point(f, prep, backend, x, tx, contexts...) end function prepare_hvp_same_point( - f::F, prep::HVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} -) where {F,C} + f::F, prep::HVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return prep end @@ -68,9 +68,9 @@ end Compute the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`. -$(docstring_preparation_hint("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point = true)) """ -function hvp(f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context,C}) where {F,C} +function hvp(f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context, C}) where {F, C} prep = prepare_hvp_nokwarg(Val(true), f, backend, x, tx, contexts...) return hvp(f, prep, backend, x, tx, contexts...) end @@ -80,11 +80,11 @@ end Compute the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`, overwriting `tg`. -$(docstring_preparation_hint("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point = true)) """ function hvp!( - f::F, tg, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f::F, tg, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_hvp_nokwarg(Val(true), f, backend, x, tx, contexts...) return hvp!(f, tg, prep, backend, x, tx, contexts...) end @@ -94,11 +94,11 @@ end Compute the gradient and the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`. -$(docstring_preparation_hint("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point = true)) """ function gradient_and_hvp( - f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_hvp_nokwarg(Val(true), f, backend, x, tx, contexts...) return gradient_and_hvp(f, prep, backend, x, tx, contexts...) end @@ -108,11 +108,11 @@ end Compute the gradient and the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`, overwriting `grad` and `tg`. -$(docstring_preparation_hint("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point = true)) """ function gradient_and_hvp!( - f::F, grad, tg, backend::AbstractADType, x, tx, contexts::Vararg{Context,C} -) where {F,C} + f::F, grad, tg, backend::AbstractADType, x, tx, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_hvp_nokwarg(Val(true), f, backend, x, tx, contexts...) return gradient_and_hvp!(f, grad, tg, prep, backend, x, tx, contexts...) end @@ -120,8 +120,8 @@ end ## Preparation function prepare_hvp_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context, C} + ) where {F, C} return _prepare_hvp_aux( strict, hvp_mode(backend), @@ -130,13 +130,13 @@ function prepare_hvp_nokwarg( backend, x, tx, - contexts...; + contexts... ) end ## Forward over anything -struct ForwardOverAnythingHVPPrep{SIG,G,GO,GI,PO,PI} <: HVPPrep{SIG} +struct ForwardOverAnythingHVPPrep{SIG, G, GO, GI, PO, PI} <: HVPPrep{SIG} # pushforward of many pushforwards in theory, but pushforward of gradient in practice _sig::Val{SIG} grad_buffer::G @@ -147,21 +147,21 @@ struct ForwardOverAnythingHVPPrep{SIG,G,GO,GI,PO,PI} <: HVPPrep{SIG} end function _prepare_hvp_aux( - strict::Val, - ::ForwardOverAnything, - ::DontPrepareInner, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::ForwardOverAnything, + ::DontPrepareInner, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Outer pushforward new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... @@ -174,7 +174,7 @@ function _prepare_hvp_aux( outer(backend), x, tx, - new_contexts...; + new_contexts... ) else nothing @@ -185,15 +185,15 @@ function _prepare_hvp_aux( end function _prepare_hvp_aux( - strict::Val, - ::ForwardOverAnything, - ::PrepareInnerSimple, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::ForwardOverAnything, + ::PrepareInnerSimple, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) @@ -228,7 +228,7 @@ function _prepare_hvp_aux( outer(backend), x, tx, - new_contexts_in...; + new_contexts_in... ) else nothing @@ -244,15 +244,15 @@ function _prepare_hvp_aux( end function _prepare_hvp_aux( - strict::Val, - ::ForwardOverAnything, - ::PrepareInnerOverload, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::ForwardOverAnything, + ::PrepareInnerOverload, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) @@ -295,7 +295,7 @@ function _prepare_hvp_aux( outer(backend), x, tx, - new_contexts_in...; + new_contexts_in... ) else nothing @@ -311,13 +311,13 @@ function _prepare_hvp_aux( end function hvp( - f::F, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) @@ -334,14 +334,14 @@ function hvp( end function hvp!( - f::F, - tg::NTuple, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + tg::NTuple, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return _hvp_aux!( inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... @@ -349,15 +349,15 @@ function hvp!( end function _hvp_aux!( - ::InPlaceSupported, - f::F, - tg::NTuple, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceSupported, + f::F, + tg::NTuple, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; grad_buffer, maybe_inner_gradient_in_prep, outer_pushforward_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -380,15 +380,15 @@ function _hvp_aux!( end function _hvp_aux!( - ::InPlaceNotSupported, - f::F, - tg::NTuple, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceNotSupported, + f::F, + tg::NTuple, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -410,13 +410,13 @@ function _hvp_aux!( end function gradient_and_hvp( - f::F, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) @@ -433,15 +433,15 @@ function gradient_and_hvp( end function gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + grad, + tg::NTuple, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return _gradient_and_hvp_aux!( inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... @@ -449,16 +449,16 @@ function gradient_and_hvp!( end function _gradient_and_hvp_aux!( - ::InPlaceSupported, - f::F, - grad, - tg::NTuple, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceSupported, + f::F, + grad, + tg::NTuple, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; maybe_inner_gradient_in_prep, outer_pushforward_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -482,16 +482,16 @@ function _gradient_and_hvp_aux!( end function _gradient_and_hvp_aux!( - ::InPlaceNotSupported, - f::F, - grad, - tg::NTuple, - prep::ForwardOverAnythingHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceNotSupported, + f::F, + grad, + tg::NTuple, + prep::ForwardOverAnythingHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -515,7 +515,7 @@ end ## Reverse over forward -struct ReverseOverForwardHVPPrep{SIG,G2<:GradientPrep,G1<:GradientPrep} <: HVPPrep{SIG} +struct ReverseOverForwardHVPPrep{SIG, G2 <: GradientPrep, G1 <: GradientPrep} <: HVPPrep{SIG} # gradient of pushforward _sig::Val{SIG} outer_gradient_prep::G2 @@ -523,15 +523,15 @@ struct ReverseOverForwardHVPPrep{SIG,G2<:GradientPrep,G1<:GradientPrep} <: HVPPr end function _prepare_hvp_aux( - strict::Val, - ::ReverseOverForward, - ::InnerPreparationBehavior, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::ReverseOverForward, + ::InnerPreparationBehavior, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( @@ -549,13 +549,13 @@ function _prepare_hvp_aux( end function hvp( - f::F, - prep::ReverseOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + f::F, + prep::ReverseOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) (; outer_gradient_prep) = prep rewrap = Rewrap(contexts...) @@ -576,14 +576,14 @@ function hvp( end function hvp!( - f::F, - tg::NTuple, - prep::ReverseOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + tg::NTuple, + prep::ReverseOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) (; outer_gradient_prep) = prep rewrap = Rewrap(contexts...) @@ -605,13 +605,13 @@ function hvp!( end function gradient_and_hvp( - f::F, - prep::ReverseOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::ReverseOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) tg = hvp(f, prep, backend, x, tx, contexts...) grad = gradient(f, prep.gradient_prep, inner(backend), x, contexts...) @@ -619,15 +619,15 @@ function gradient_and_hvp( end function gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ReverseOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + grad, + tg::NTuple, + prep::ReverseOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) hvp!(f, tg, prep, backend, x, tx, contexts...) gradient!(f, grad, prep.gradient_prep, inner(backend), x, contexts...) @@ -636,7 +636,7 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{SIG,G,PO,PI} <: HVPPrep{SIG} +struct ReverseOverReverseHVPPrep{SIG, G, PO, PI} <: HVPPrep{SIG} # pullback of gradient _sig::Val{SIG} grad_buffer::G @@ -645,19 +645,19 @@ struct ReverseOverReverseHVPPrep{SIG,G,PO,PI} <: HVPPrep{SIG} end function _prepare_hvp_aux( - strict::Val, - ::ReverseOverReverse, - ::InnerPreparationBehavior, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; -) where {F,C} + strict::Val, + ::ReverseOverReverse, + ::InnerPreparationBehavior, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) grad_buffer = similar(x) outer_pullback_prep = prepare_pullback_nokwarg( @@ -671,7 +671,7 @@ function _prepare_hvp_aux( outer(backend), x, tx, - new_contexts...; + new_contexts... ) else nothing @@ -682,18 +682,18 @@ function _prepare_hvp_aux( end function hvp( - f::F, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) return pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -701,14 +701,14 @@ function hvp( end function hvp!( - f::F, - tg::NTuple, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return _hvp_aux!( inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... @@ -716,19 +716,19 @@ function hvp!( end function _hvp_aux!( - ::InPlaceSupported, - f::F, - tg::NTuple, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceSupported, + f::F, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; grad_buffer, outer_pullback_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) return pullback!( shuffled_gradient!, @@ -743,19 +743,19 @@ function _hvp_aux!( end function _hvp_aux!( - ::InPlaceNotSupported, - f::F, - tg::NTuple, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceNotSupported, + f::F, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) return pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -763,18 +763,18 @@ function _hvp_aux!( end function gradient_and_hvp( - f::F, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) return value_and_pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -782,15 +782,15 @@ function gradient_and_hvp( end function gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + grad, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, tx, contexts...) return _gradient_and_hvp_aux!( inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... @@ -798,20 +798,20 @@ function gradient_and_hvp!( end function _gradient_and_hvp_aux!( - ::InPlaceSupported, - f::F, - grad, - tg::NTuple, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceSupported, + f::F, + grad, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; outer_pullback_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) new_grad, _ = value_and_pullback!( shuffled_gradient!, @@ -827,20 +827,20 @@ function _gradient_and_hvp_aux!( end function _gradient_and_hvp_aux!( - ::InPlaceNotSupported, - f::F, - grad, - tg::NTuple, - prep::ReverseOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} + ::InPlaceNotSupported, + f::F, + grad, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context, C}, + ) where {F, C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) new_grad, _ = value_and_pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 3d5b8c905..1e239df10 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -6,8 +6,8 @@ $(docstring_prepare("second_derivative")) """ function prepare_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val = Val(true) + ) where {F, C} return prepare_second_derivative_nokwarg(strict, f, backend, x, contexts...) end @@ -17,12 +17,12 @@ end $(docstring_prepare!("second_derivative")) """ function prepare!_second_derivative( - f::F, - old_prep::SecondDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; -) where {F,C} + f::F, + old_prep::SecondDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C} + ) where {F, C} check_prep(f, old_prep, backend, x, contexts...) return prepare_second_derivative_nokwarg( is_strict(old_prep), f, backend, x, contexts... @@ -37,8 +37,8 @@ Compute the second derivative of the function `f` at point `x`. $(docstring_preparation_hint("second_derivative")) """ function second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_second_derivative_nokwarg(Val(true), f, backend, x, contexts...) return second_derivative(f, prep, backend, x, contexts...) end @@ -51,8 +51,8 @@ Compute the second derivative of the function `f` at point `x`, overwriting `der $(docstring_preparation_hint("second_derivative")) """ function second_derivative!( - f::F, der2, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, der2, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_second_derivative_nokwarg(Val(true), f, backend, x, contexts...) return second_derivative!(f, der2, prep, backend, x, contexts...) end @@ -65,8 +65,8 @@ Compute the value, first derivative and second derivative of the function `f` at $(docstring_preparation_hint("second_derivative")) """ function value_derivative_and_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_second_derivative_nokwarg(Val(true), f, backend, x, contexts...) return value_derivative_and_second_derivative(f, prep, backend, x, contexts...) end @@ -79,8 +79,8 @@ Compute the value, first derivative and second derivative of the function `f` at $(docstring_preparation_hint("second_derivative")) """ function value_derivative_and_second_derivative!( - f::F, der, der2, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, der, der2, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} prep = prepare_second_derivative_nokwarg(Val(true), f, backend, x, contexts...) return value_derivative_and_second_derivative!( f, der, der2, prep, backend, x, contexts... @@ -89,18 +89,18 @@ end ## Preparation -struct DerivativeSecondDerivativePrep{SIG,E<:DerivativePrep} <: SecondDerivativePrep{SIG} +struct DerivativeSecondDerivativePrep{SIG, E <: DerivativePrep} <: SecondDerivativePrep{SIG} _sig::Val{SIG} outer_derivative_prep::E end function prepare_second_derivative_nokwarg( - strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {F,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {F, C} _sig = signature(f, backend, x, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) outer_derivative_prep = prepare_derivative_nokwarg( strict, shuffled_derivative, outer(backend), x, new_contexts... @@ -111,17 +111,17 @@ end ## One argument function second_derivative( - f::F, - prep::DerivativeSecondDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::DerivativeSecondDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) return derivative( shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts... @@ -129,17 +129,17 @@ function second_derivative( end function value_derivative_and_second_derivative( - f::F, - prep::DerivativeSecondDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + prep::DerivativeSecondDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) y = f(x, map(unwrap, contexts)...) der, der2 = value_and_derivative( @@ -149,18 +149,18 @@ function value_derivative_and_second_derivative( end function second_derivative!( - f::F, - der2, - prep::SecondDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + der2, + prep::SecondDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) return derivative!( shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts... @@ -168,19 +168,19 @@ function second_derivative!( end function value_derivative_and_second_derivative!( - f::F, - der, - der2, - prep::SecondDerivativePrep, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}, -) where {F,C} + f::F, + der, + der2, + prep::SecondDerivativePrep, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, C} check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts..., ) y = f(x, map(unwrap, contexts)...) new_der, _ = value_and_derivative!( diff --git a/DifferentiationInterface/src/second_order/second_order.jl b/DifferentiationInterface/src/second_order/second_order.jl index 0045869dc..2a0a9c2f4 100644 --- a/DifferentiationInterface/src/second_order/second_order.jl +++ b/DifferentiationInterface/src/second_order/second_order.jl @@ -16,7 +16,7 @@ Combination of two backends for second-order differentiation. - `outer::AbstractADType`: backend for the outer differentiation - `inner::AbstractADType`: backend for the inner differentiation """ -struct SecondOrder{ADO<:AbstractADType,ADI<:AbstractADType} <: AbstractADType +struct SecondOrder{ADO <: AbstractADType, ADI <: AbstractADType} <: AbstractADType outer::ADO inner::ADI end @@ -26,9 +26,9 @@ function Base.show(io::IO, backend::SecondOrder) io, SecondOrder, "(", - repr(outer(backend); context=io), + repr(outer(backend); context = io), ", ", - repr(inner(backend); context=io), + repr(inner(backend); context = io), ")", ) end diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl index 1e55721e9..f484a11db 100644 --- a/DifferentiationInterface/src/utils/batchsize.jl +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -15,13 +15,13 @@ Configuration for the batch size deduced from a backend and a sample array of le - `A::Int`: number of batches `A = div(N, B, RoundUp)` - `B_last::Int`: size of the last batch (if `aligned` is `false`) """ -struct BatchSizeSettings{B,singlebatch,aligned} +struct BatchSizeSettings{B, singlebatch, aligned} N::Int A::Int B_last::Int end -function BatchSizeSettings{B,singlebatch,aligned}(N::Integer) where {B,singlebatch,aligned} +function BatchSizeSettings{B, singlebatch, aligned}(N::Integer) where {B, singlebatch, aligned} B > N > 0 && throw(ArgumentError("Batch size $B larger than input size $N")) if B == N == 0 A = B_last = 0 @@ -29,20 +29,20 @@ function BatchSizeSettings{B,singlebatch,aligned}(N::Integer) where {B,singlebat A = div(N, B, RoundUp) B_last = N % B end - return BatchSizeSettings{B,singlebatch,aligned}(N, A, B_last) + return BatchSizeSettings{B, singlebatch, aligned}(N, A, B_last) end -function BatchSizeSettings{B}(::Val{N}) where {B,N} +function BatchSizeSettings{B}(::Val{N}) where {B, N} singlebatch = B == N aligned = (B == N == 0) || (N % B == 0) - return BatchSizeSettings{B,singlebatch,aligned}(N) + return BatchSizeSettings{B, singlebatch, aligned}(N) end function BatchSizeSettings{B}(N::Integer) where {B} # type-unstable singlebatch = B == N aligned = (B == N == 0) || (N % B == 0) - return BatchSizeSettings{B,singlebatch,aligned}(N) + return BatchSizeSettings{B, singlebatch, aligned}(N) end """ @@ -55,7 +55,7 @@ function pick_batchsize(backend::AbstractADType, N::Integer) B = 1 singlebatch = false aligned = true - return BatchSizeSettings{B,singlebatch,aligned}(N) + return BatchSizeSettings{B, singlebatch, aligned}(N) end """ @@ -71,7 +71,7 @@ function pick_batchsize(backend::AbstractADType, x::AbstractArray) end function check_batchsize_pickable(backend::AbstractADType) - if backend isa SecondOrder + return if backend isa SecondOrder throw( ArgumentError( "You should select the batch size for the inner or outer backend of $backend", @@ -105,8 +105,8 @@ threshold_batchsize(backend::AbstractADType, ::Integer) = backend function threshold_batchsize(backend::AutoSparse, B::Integer) return AutoSparse( threshold_batchsize(dense_ad(backend), B); - sparsity_detector=backend.sparsity_detector, - coloring_algorithm=backend.coloring_algorithm, + sparsity_detector = backend.sparsity_detector, + coloring_algorithm = backend.coloring_algorithm, ) end diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl index 311da5ef6..0c98fb938 100644 --- a/DifferentiationInterface/src/utils/check.jl +++ b/DifferentiationInterface/src/utils/check.jl @@ -13,7 +13,7 @@ check_available(backend::AutoSparse) = check_available(dense_ad(backend)) function check_available(backend::MixedMode) return check_available(forward_backend(backend)) && - check_available(reverse_backend(backend)) + check_available(reverse_backend(backend)) end check_available(::ADTypes.NoAutoDiff) = throw(ADTypes.NoAutoDiffSelectedError()) diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index f834a90e7..2d2575d01 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -133,17 +133,17 @@ Utility for recording context types of additional arguments (e.g. `Constant` or Useful for second-order differentiation. """ -struct Rewrap{C,T} +struct Rewrap{C, T} context_makers::T - function Rewrap(contexts::Vararg{Context,C}) where {C} + function Rewrap(contexts::Vararg{Context, C}) where {C} context_makers = map(maker, contexts) - return new{C,typeof(context_makers)}(context_makers) + return new{C, typeof(context_makers)}(context_makers) end end (::Rewrap{0})() = () -function (r::Rewrap{C,T})(unannotated_contexts::Vararg{Any,C}) where {C,T} +function (r::Rewrap{C, T})(unannotated_contexts::Vararg{Any, C}) where {C, T} return map(r.context_makers, unannotated_contexts) do maker, c maker(c) end @@ -160,15 +160,15 @@ Closure around a function `f` and a set of tail argument `tail_args` such that (ft::FixTail)(args...) = ft.f(args..., ft.tail_args...) ``` """ -struct FixTail{F,A<:Tuple} +struct FixTail{F, A <: Tuple} f::F tail_args::A - function FixTail(f::F, tail_args::Vararg{Any,N}) where {F,N} - return new{F,typeof(tail_args)}(f, tail_args) + function FixTail(f::F, tail_args::Vararg{Any, N}) where {F, N} + return new{F, typeof(tail_args)}(f, tail_args) end end -function (ft::FixTail)(args::Vararg{Any,N}) where {N} +function (ft::FixTail)(args::Vararg{Any, N}) where {N} return ft.f(args..., ft.tail_args...) end @@ -178,4 +178,4 @@ end Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are no tail arguments. """ @inline fix_tail(f::F) where {F} = f -fix_tail(f::F, args::Vararg{Any,N}) where {F,N} = FixTail(f, args...) +fix_tail(f::F, args::Vararg{Any, N}) where {F, N} = FixTail(f, args...) diff --git a/DifferentiationInterface/src/utils/errors.jl b/DifferentiationInterface/src/utils/errors.jl index 64308ab22..3e1dae5ed 100644 --- a/DifferentiationInterface/src/utils/errors.jl +++ b/DifferentiationInterface/src/utils/errors.jl @@ -1,6 +1,6 @@ required_packages(b::AbstractADType) = required_packages(typeof(b)) -function required_packages(::Type{B}) where {B<:AbstractADType} +function required_packages(::Type{B}) where {B <: AbstractADType} s = string(B) s = chopprefix(s, "ADTypes.") s = chopprefix(s, "Auto") @@ -12,13 +12,13 @@ function required_packages(::Type{B}) where {B<:AbstractADType} end end -function required_packages(::Type{SecondOrder{O,I}}) where {O,I} +function required_packages(::Type{SecondOrder{O, I}}) where {O, I} p1 = required_packages(O) p2 = required_packages(I) return unique(vcat(p1, p2)) end -function required_packages(::Type{MixedMode{F,R}}) where {F,R} +function required_packages(::Type{MixedMode{F, R}}) where {F, R} p1 = required_packages(F) p2 = required_packages(R) return unique(vcat(p1, p2)) diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index 908c32c89..0be85a3ee 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -1,5 +1,5 @@ -stack_vec_col(t::NTuple) = stack(vec, t; dims=2) -stack_vec_row(t::NTuple) = stack(vec, t; dims=1) +stack_vec_col(t::NTuple) = stack(vec, t; dims = 2) +stack_vec_row(t::NTuple) = stack(vec, t; dims = 1) """ ismutable_array(x) @@ -19,7 +19,7 @@ Apply `similar(_, T)` recursively to `x` or its components. Works if `x` is an `AbstractArray` or a (nested) `NTuple` / `NamedTuple` of `AbstractArray`s. """ recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T) -function recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T} +function recursive_similar(x::Union{Tuple, NamedTuple}, ::Type{T}) where {T} return map(xi -> recursive_similar(xi, T), x) end diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 2d0107d96..95ea16609 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -83,19 +83,19 @@ end is_strict(::Prep{Nothing}) = Val(false) is_strict(::Prep) = Val(true) -struct PreparationMismatchError{SIG,EXEC_SIG} <: Exception +struct PreparationMismatchError{SIG, EXEC_SIG} <: Exception format::Vector{Symbol} end function PreparationMismatchError( - ::Type{SIG}, ::Type{EXEC_SIG}; format -) where {SIG,EXEC_SIG} - return PreparationMismatchError{SIG,EXEC_SIG}(format) + ::Type{SIG}, ::Type{EXEC_SIG}; format + ) where {SIG, EXEC_SIG} + return PreparationMismatchError{SIG, EXEC_SIG}(format) end function Base.showerror( - io::IO, e::PreparationMismatchError{SIG,EXEC_SIG} -) where {SIG<:Tuple,EXEC_SIG<:Tuple} + io::IO, e::PreparationMismatchError{SIG, EXEC_SIG} + ) where {SIG <: Tuple, EXEC_SIG <: Tuple} println( io, "PreparationMismatchError (inconsistent types between preparation and execution):", @@ -115,8 +115,8 @@ function Base.showerror( end function signature( - f, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val{S} -) where {C,S} + f, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val{S} + ) where {C, S} if S return Val(typeof((f, backend, x, contexts))) else @@ -125,8 +125,8 @@ function signature( end function signature( - f!, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val{S} -) where {C,S} + f!, y, backend::AbstractADType, x, contexts::Vararg{Context, C}; strict::Val{S} + ) where {C, S} if S return Val(typeof((f!, y, backend, x, contexts))) else @@ -135,8 +135,8 @@ function signature( end function signature( - f, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}; strict::Val{S} -) where {C,S} + f, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context, C}; strict::Val{S} + ) where {C, S} if S return Val(typeof((f, backend, x, t, contexts))) else @@ -145,14 +145,14 @@ function signature( end function signature( - f!, - y, - backend::AbstractADType, - x, - t::NTuple, - contexts::Vararg{Context,C}; - strict::Val{S}, -) where {C,S} + f!, + y, + backend::AbstractADType, + x, + t::NTuple, + contexts::Vararg{Context, C}; + strict::Val{S}, + ) where {C, S} if S return Val(typeof((f!, y, backend, x, t, contexts))) else @@ -161,14 +161,14 @@ function signature( end function check_prep( - f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {SIG,C} - if SIG !== Nothing + f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {SIG, C} + return if SIG !== Nothing EXEC_SIG = typeof((f, backend, x, contexts)) if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, EXEC_SIG; format=[:f, :backend, :x, :contexts] + SIG, EXEC_SIG; format = [:f, :backend, :x, :contexts] ), ) end @@ -176,14 +176,14 @@ function check_prep( end function check_prep( - f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} -) where {SIG,C} - if SIG !== Nothing + f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context, C} + ) where {SIG, C} + return if SIG !== Nothing EXEC_SIG = typeof((f!, y, backend, x, contexts)) if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :contexts] + SIG, EXEC_SIG; format = [:f!, :y, :backend, :x, :contexts] ), ) end @@ -191,14 +191,14 @@ function check_prep( end function check_prep( - f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} -) where {SIG,C} - if SIG !== Nothing + f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context, C} + ) where {SIG, C} + return if SIG !== Nothing EXEC_SIG = typeof((f, backend, x, t, contexts)) if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, EXEC_SIG; format=[:f, :backend, :x, :t, :contexts] + SIG, EXEC_SIG; format = [:f, :backend, :x, :t, :contexts] ), ) end @@ -206,14 +206,14 @@ function check_prep( end function check_prep( - f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} -) where {SIG,C} - if SIG !== Nothing + f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context, C} + ) where {SIG, C} + return if SIG !== Nothing EXEC_SIG = typeof((f!, y, backend, x, t, contexts)) if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :t, :contexts] + SIG, EXEC_SIG; format = [:f!, :y, :backend, :x, :t, :contexts] ), ) end diff --git a/DifferentiationInterface/src/utils/sparse.jl b/DifferentiationInterface/src/utils/sparse.jl index f79650503..559c1b850 100644 --- a/DifferentiationInterface/src/utils/sparse.jl +++ b/DifferentiationInterface/src/utils/sparse.jl @@ -5,14 +5,14 @@ Wrapper around [`ADTypes.jacobian_sparsity`](@extref ADTypes.jacobian_sparsity) enabling the allocation of caches with proper element types. """ function jacobian_sparsity_with_contexts( - f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context, C} + ) where {F, C} return jacobian_sparsity(fix_tail(f, map(unwrap, contexts)...), x, detector) end function jacobian_sparsity_with_contexts( - f!::F, y, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} -) where {F,C} + f!::F, y, detector::AbstractSparsityDetector, x, contexts::Vararg{Context, C} + ) where {F, C} return jacobian_sparsity(fix_tail(f!, map(unwrap, contexts)...), y, x, detector) end @@ -22,7 +22,7 @@ end Wrapper around [`ADTypes.hessian_sparsity`](@extref ADTypes.hessian_sparsity) enabling the allocation of caches with proper element types. """ function hessian_sparsity_with_contexts( - f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} -) where {F,C} + f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context, C} + ) where {F, C} return hessian_sparsity(fix_tail(f, map(unwrap, contexts)...), x, detector) end diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index b1f4ed5b5..017f0d47a 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -27,7 +27,7 @@ inplace_support(::ADTypes.NoAutoDiff) = throw(ADTypes.NoAutoDiffSelectedError()) function inplace_support(backend::SecondOrder) if inplace_support(inner(backend)) isa InPlaceSupported && - inplace_support(outer(backend)) isa InPlaceSupported + inplace_support(outer(backend)) isa InPlaceSupported return InPlaceSupported() else return InPlaceNotSupported() @@ -38,7 +38,7 @@ inplace_support(backend::AutoSparse) = inplace_support(dense_ad(backend)) function inplace_support(backend::MixedMode) if Bool(inplace_support(forward_backend(backend))) && - Bool(inplace_support(reverse_backend(backend))) + Bool(inplace_support(reverse_backend(backend))) return InPlaceSupported() else return InPlaceNotSupported() @@ -74,7 +74,7 @@ pushforward_performance(::ForwardOrReverseMode) = PushforwardFast() pushforward_performance(::ReverseMode) = PushforwardSlow() pushforward_performance(::SymbolicMode) = PushforwardFast() -function pushforward_performance(backend::Union{AutoSparse,SecondOrder}) +function pushforward_performance(backend::Union{AutoSparse, SecondOrder}) throw(ArgumentError("Pushforward performance not defined for $backend`.")) end @@ -107,7 +107,7 @@ pullback_performance(::ForwardOrReverseMode) = PullbackFast() pullback_performance(::ReverseMode) = PullbackFast() pullback_performance(::SymbolicMode) = PullbackFast() -function pullback_performance(backend::Union{AutoSparse,SecondOrder}) +function pullback_performance(backend::Union{AutoSparse, SecondOrder}) throw(ArgumentError("Pullback performance not defined for $backend`.")) end @@ -143,7 +143,7 @@ Traits identifying second-order backends that compute HVPs in forward over forwa """ struct ForwardOverForward <: HVPMode end -const ForwardOverAnything = Union{ForwardOverForward,ForwardOverReverse} +const ForwardOverAnything = Union{ForwardOverForward, ForwardOverReverse} """ hvp_mode(backend) diff --git a/DifferentiationInterface/test/Back/ChainRules/zygote.jl b/DifferentiationInterface/test/Back/ChainRules/zygote.jl index ef54db928..9722220f4 100644 --- a/DifferentiationInterface/test/Back/ChainRules/zygote.jl +++ b/DifferentiationInterface/test/Back/ChainRules/zygote.jl @@ -19,13 +19,13 @@ end test_differentiation( AutoChainRules(ZygoteRuleConfig()), default_scenarios(); - excluded=[:second_derivative], - logging=LOGGING, + excluded = [:second_derivative], + logging = LOGGING, ); test_differentiation( AutoChainRules(ZygoteRuleConfig()), - default_scenarios(; include_normal=false, include_constantified=true); - excluded=SECOND_ORDER, - logging=LOGGING, + default_scenarios(; include_normal = false, include_constantified = true); + excluded = SECOND_ORDER, + logging = LOGGING, ); diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 9c655e001..04d9f5b99 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -43,27 +43,27 @@ function differentiatewith_scenarios() end test_differentiation( - [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)], + [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config = nothing)], differentiatewith_scenarios(); - excluded=SECOND_ORDER, - logging=LOGGING, - testset_name="DI tests", + excluded = SECOND_ORDER, + logging = LOGGING, + testset_name = "DI tests", ) @testset "ChainRules tests" begin @testset for scen in filter(differentiatewith_scenarios()) do scen - DIT.operator(scen) == :pullback - end - ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4) + DIT.operator(scen) == :pullback + end + ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol = 1.0e-4) end end; @testset "Mooncake tests" begin @testset for scen in filter(differentiatewith_scenarios()) do scen - DIT.operator(scen) == :pullback - end + DIT.operator(scen) == :pullback + end Mooncake.TestUtils.test_rule( - StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode + StableRNG(0), scen.f, scen.x; is_primitive = true, mode = Mooncake.ReverseMode ) end end; @@ -83,25 +83,25 @@ end; @test_throws MooncakeDifferentiateWithError pullback( DifferentiateWith(f_num2tup, AutoFiniteDiff()), - AutoMooncake(; config=nothing), + AutoMooncake(; config = nothing), 1.0, ((2.0,),), ) @test_throws MooncakeDifferentiateWithError pullback( DifferentiateWith(f_vec2tup, AutoFiniteDiff()), - AutoMooncake(; config=nothing), + AutoMooncake(; config = nothing), [1.0], ((2.0,),), ) @test_throws MethodError pullback( DifferentiateWith(f_tup2num, AutoFiniteDiff()), - AutoMooncake(; config=nothing), + AutoMooncake(; config = nothing), (1.0,), (2.0,), ) @test_throws MethodError pullback( DifferentiateWith(f_tup2vec, AutoFiniteDiff()), - AutoMooncake(; config=nothing), + AutoMooncake(; config = nothing), (1.0,), ([2.0],), ) diff --git a/DifferentiationInterface/test/Back/Diffractor/test.jl b/DifferentiationInterface/test/Back/Diffractor/test.jl index 08315b9c3..57be809ea 100644 --- a/DifferentiationInterface/test/Back/Diffractor/test.jl +++ b/DifferentiationInterface/test/Back/Diffractor/test.jl @@ -17,7 +17,7 @@ end test_differentiation( AutoDiffractor(), - default_scenarios(; linalg=false); - excluded=SECOND_ORDER, - logging=LOGGING, + default_scenarios(; linalg = false); + excluded = SECOND_ORDER, + logging = LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index cb4fce824..892ac6c78 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -17,15 +17,15 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" backends = [ - AutoEnzyme(; mode=nothing), - AutoEnzyme(; mode=Enzyme.Forward), - AutoEnzyme(; mode=Enzyme.Reverse), - AutoEnzyme(; mode=nothing, function_annotation=Enzyme.Const), + AutoEnzyme(; mode = nothing), + AutoEnzyme(; mode = Enzyme.Forward), + AutoEnzyme(; mode = Enzyme.Reverse), + AutoEnzyme(; mode = nothing, function_annotation = Enzyme.Const), ] duplicated_backends = [ - AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Duplicated), - AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Duplicated), + AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated), + AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated), ] @testset "Checks" begin @@ -37,33 +37,33 @@ end; @testset "First order" begin test_differentiation( - backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING + backends, default_scenarios(); excluded = SECOND_ORDER, logging = LOGGING ) test_differentiation( backends[1:3], - default_scenarios(; include_normal=false, include_constantified=true); - excluded=SECOND_ORDER, - logging=LOGGING, + default_scenarios(; include_normal = false, include_constantified = true); + excluded = SECOND_ORDER, + logging = LOGGING, ) test_differentiation( backends[2:3], default_scenarios(; - include_normal=false, - include_cachified=true, - include_constantorcachified=true, - use_tuples=true, + include_normal = false, + include_cachified = true, + include_constantorcachified = true, + use_tuples = true, ); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ) test_differentiation( duplicated_backends, - default_scenarios(; include_normal=false, include_closurified=true); - excluded=SECOND_ORDER, - logging=LOGGING, + default_scenarios(; include_normal = false, include_closurified = true); + excluded = SECOND_ORDER, + logging = LOGGING, ) end @@ -72,25 +72,25 @@ end [ AutoEnzyme(), SecondOrder( - AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward) + AutoEnzyme(; mode = Enzyme.Reverse), AutoEnzyme(; mode = Enzyme.Forward) ), ], - default_scenarios(; include_constantified=true, include_cachified=true); - excluded=FIRST_ORDER, - logging=LOGGING, + default_scenarios(; include_constantified = true, include_cachified = true); + excluded = FIRST_ORDER, + logging = LOGGING, ) end @testset "Sparse" begin test_differentiation( - MyAutoSparse.(AutoEnzyme(; function_annotation=Enzyme.Const)), + MyAutoSparse.(AutoEnzyme(; function_annotation = Enzyme.Const)), if VERSION < v"1.11" sparse_scenarios() else filter(s -> s.x isa AbstractVector, sparse_scenarios()) end; - sparsity=true, - logging=LOGGING, + sparsity = true, + logging = LOGGING, ) end @@ -100,10 +100,10 @@ end end test_differentiation( - [AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)], + [AutoEnzyme(; mode = Enzyme.Forward), AutoEnzyme(; mode = Enzyme.Reverse)], filtered_static_scenarios; - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ) end @@ -111,10 +111,10 @@ end # ConstantOrCache without cache f_nocontext(x, p) = x @test I == DifferentiationInterface.jacobian( - f_nocontext, AutoEnzyme(; mode=Enzyme.Forward), rand(10), ConstantOrCache(nothing) + f_nocontext, AutoEnzyme(; mode = Enzyme.Forward), rand(10), ConstantOrCache(nothing) ) @test I == DifferentiationInterface.jacobian( - f_nocontext, AutoEnzyme(; mode=Enzyme.Reverse), rand(10), ConstantOrCache(nothing) + f_nocontext, AutoEnzyme(; mode = Enzyme.Reverse), rand(10), ConstantOrCache(nothing) ) end @@ -153,7 +153,7 @@ end try pushforward( h, - AutoEnzyme(; mode=Enzyme.Forward), + AutoEnzyme(; mode = Enzyme.Forward), [1.0], ([1.0],), Constant([1.0]), @@ -169,8 +169,8 @@ end @testset "Empty arrays" begin test_differentiation( - [AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)], + [AutoEnzyme(; mode = Enzyme.Forward), AutoEnzyme(; mode = Enzyme.Reverse)], empty_scenarios(); - excluded=[:jacobian], + excluded = [:jacobian], ) end; diff --git a/DifferentiationInterface/test/Back/FiniteDiff/benchmark.jl b/DifferentiationInterface/test/Back/FiniteDiff/benchmark.jl index f2783696c..96a0cb61e 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/benchmark.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/benchmark.jl @@ -11,7 +11,7 @@ using Test LOGGING = get(ENV, "CI", "false") == "false" @testset "Benchmarking sparse" begin - filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes=[])) do scen + filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes = [])) do scen DIT.function_place(scen) == :in && DIT.operator_place(scen) == :in && scen.x isa AbstractVector && @@ -21,9 +21,9 @@ LOGGING = get(ENV, "CI", "false") == "false" data = benchmark_differentiation( MyAutoSparse(AutoFiniteDiff()), filtered_sparse_scenarios; - benchmark=:prepared, - excluded=SECOND_ORDER, - logging=LOGGING, + benchmark = :prepared, + excluded = SECOND_ORDER, + logging = LOGGING, ) @testset "Analyzing benchmark results" begin @testset "$(row[:scenario])" for row in eachrow(data) diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index 9681bcab5..a7f5ad42f 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -23,31 +23,31 @@ end test_differentiation( AutoFiniteDiff(), default_scenarios(; - include_constantified=true, - include_cachified=true, - include_constantorcachified=true, - use_tuples=true, - include_smaller=true, + include_constantified = true, + include_cachified = true, + include_constantorcachified = true, + use_tuples = true, + include_smaller = true, ); - excluded=[:second_derivative, :hvp], - logging=LOGGING, + excluded = [:second_derivative, :hvp], + logging = LOGGING, ) test_differentiation( - SecondOrder(AutoFiniteDiff(; relstep=1e-5, absstep=1e-5), AutoFiniteDiff()), + SecondOrder(AutoFiniteDiff(; relstep = 1.0e-5, absstep = 1.0e-5), AutoFiniteDiff()), default_scenarios(); - logging=LOGGING, - rtol=1e-2, + logging = LOGGING, + rtol = 1.0e-2, ) test_differentiation( [ - AutoFiniteDiff(; relstep=cbrt(eps(Float64))), - AutoFiniteDiff(; relstep=cbrt(eps(Float64)), absstep=cbrt(eps(Float64))), - AutoFiniteDiff(; dir=0.5), + AutoFiniteDiff(; relstep = cbrt(eps(Float64))), + AutoFiniteDiff(; relstep = cbrt(eps(Float64)), absstep = cbrt(eps(Float64))), + AutoFiniteDiff(; dir = 0.5), ]; - excluded=[:second_derivative, :hvp], - logging=LOGGING, + excluded = [:second_derivative, :hvp], + logging = LOGGING, ) end @@ -55,26 +55,26 @@ end test_differentiation( MyAutoSparse(AutoFiniteDiff()), sparse_scenarios(); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ) end @testset "Complex" begin - test_differentiation(AutoFiniteDiff(), complex_scenarios(); logging=LOGGING) + test_differentiation(AutoFiniteDiff(), complex_scenarios(); logging = LOGGING) test_differentiation( AutoSparse( AutoFiniteDiff(); - sparsity_detector=DenseSparsityDetector(AutoFiniteDiff(); atol=1e-5), - coloring_algorithm=GreedyColoringAlgorithm(), + sparsity_detector = DenseSparsityDetector(AutoFiniteDiff(); atol = 1.0e-5), + coloring_algorithm = GreedyColoringAlgorithm(), ), complex_sparse_scenarios(); - logging=LOGGING, + logging = LOGGING, ) end; @testset "Step size" begin # fix 811 - backend = AutoFiniteDiff(; absstep=1000, relstep=0.1) + backend = AutoFiniteDiff(; absstep = 1000, relstep = 0.1) preps = [ prepare_pushforward(identity, backend, 1.0, (1.0,)), prepare_pushforward(copyto!, [0.0], backend, [1.0], ([1.0],)), @@ -94,7 +94,7 @@ end; @test prep.relstep_g == 0.1 @test prep.relstep_h == 0.1 - backend = AutoFiniteDiff(; relstep=0.1) + backend = AutoFiniteDiff(; relstep = 0.1) preps = [ prepare_pushforward(identity, backend, 1.0, (1.0,)), prepare_pushforward(copyto!, [0.0], backend, [1.0], ([1.0],)), diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index c55ee3bd2..7157a043c 100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -10,7 +10,7 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1))] +for backend in [AutoFiniteDifferences(; fdm = FiniteDifferences.central_fdm(3, 1))] @test check_available(backend) @test !check_inplace(backend) @test DifferentiationInterface.inner_preparation_behavior(backend) isa @@ -18,10 +18,10 @@ for backend in [AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)) end test_differentiation( - AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), + AutoFiniteDifferences(; fdm = FiniteDifferences.central_fdm(3, 1)), default_scenarios(; - include_constantified=true, include_cachified=true, use_tuples=true + include_constantified = true, include_cachified = true, use_tuples = true ); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ); diff --git a/DifferentiationInterface/test/Back/ForwardDiff/benchmark.jl b/DifferentiationInterface/test/Back/ForwardDiff/benchmark.jl index aea81d03b..256c8aaa2 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/benchmark.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/benchmark.jl @@ -12,15 +12,15 @@ using Test LOGGING = get(ENV, "CI", "false") == "false" @testset verbose = true "Benchmarking static" begin - filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen + filtered_static_scenarios = filter(static_scenarios(; include_batchified = false)) do scen DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out end data = benchmark_differentiation( AutoForwardDiff(), filtered_static_scenarios; - benchmark=:prepared, - excluded=[:hessian, :pullback], # TODO: figure this out - logging=LOGGING, + benchmark = :prepared, + excluded = [:hessian, :pullback], # TODO: figure this out + logging = LOGGING, ) @testset "Analyzing benchmark results" begin @testset "$(row[:scenario])" for row in eachrow(data) @@ -30,7 +30,7 @@ LOGGING = get(ENV, "CI", "false") == "false" end @testset "Benchmarking sparse" begin - filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes=[])) do scen + filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes = [])) do scen DIT.function_place(scen) == :in && DIT.operator_place(scen) == :in && scen.x isa AbstractVector && @@ -40,9 +40,9 @@ end data = benchmark_differentiation( MyAutoSparse(AutoForwardDiff()), filtered_sparse_scenarios; - benchmark=:prepared, - excluded=SECOND_ORDER, - logging=LOGGING, + benchmark = :prepared, + excluded = SECOND_ORDER, + logging = LOGGING, ) @testset "Analyzing benchmark results" begin @testset "$(row[:scenario])" for row in eachrow(data) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 42448a7c0..ef501aec1 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -20,8 +20,8 @@ struct MyTag end backends = [ AutoForwardDiff(), - AutoForwardDiff(; chunksize=5), - AutoForwardDiff(; tag=ForwardDiff.Tag(MyTag(), Float64)), + AutoForwardDiff(; chunksize = 5), + AutoForwardDiff(; tag = ForwardDiff.Tag(MyTag(), Float64)), ] for backend in backends @@ -31,67 +31,69 @@ end @testset "Dense" begin test_differentiation( - backends, default_scenarios(; include_constantified=true); logging=LOGGING + backends, default_scenarios(; include_constantified = true); logging = LOGGING ) test_differentiation( AutoForwardDiff(), default_scenarios(; - include_normal=false, - include_batchified=false, - include_cachified=true, - include_constantorcachified=true, - use_tuples=true, - include_smaller=true, + include_normal = false, + include_batchified = false, + include_cachified = true, + include_constantorcachified = true, + use_tuples = true, + include_smaller = true, ); - logging=LOGGING, + logging = LOGGING, ) test_differentiation( AutoForwardDiff(); - correctness=false, - type_stability=safetypestab(:prepared), - logging=LOGGING, + correctness = false, + type_stability = safetypestab(:prepared), + logging = LOGGING, ) test_differentiation( - AutoForwardDiff(; chunksize=5); - correctness=false, - type_stability=safetypestab(:full), - excluded=[:hessian], - logging=LOGGING, + AutoForwardDiff(; chunksize = 5); + correctness = false, + type_stability = safetypestab(:full), + excluded = [:hessian], + logging = LOGGING, ) end @testset "Sparse" begin test_differentiation( - MyAutoSparse(AutoForwardDiff()), default_scenarios(); logging=LOGGING + MyAutoSparse(AutoForwardDiff()), default_scenarios(); logging = LOGGING ) test_differentiation( MyAutoSparse(AutoForwardDiff()), - sparse_scenarios(; include_constantified=true); - sparsity=true, - logging=LOGGING, + sparse_scenarios(; include_constantified = true); + sparsity = true, + logging = LOGGING, ) end @testset "Weird" begin - test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING) - test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) + test_differentiation(AutoForwardDiff(), component_scenarios(); logging = LOGGING) + test_differentiation(AutoForwardDiff(), static_scenarios(); logging = LOGGING) test_differentiation( - DI.AutoForwardFromPrimitive(AutoForwardDiff()), gpu_scenarios(); logging=LOGGING + DI.AutoForwardFromPrimitive(AutoForwardDiff()), gpu_scenarios(); logging = LOGGING ) @testset "Batch size" begin @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7} - @test DI.pick_batchsize(AutoForwardDiff(; chunksize=5), rand(7)) isa + @test DI.pick_batchsize(AutoForwardDiff(; chunksize = 5), rand(7)) isa DI.BatchSizeSettings{5} @test (@inferred DI.pick_batchsize(AutoForwardDiff(), @SVector(rand(7)))) isa DI.BatchSizeSettings{7} - @test (@inferred DI.pick_batchsize( - AutoForwardDiff(; chunksize=5), @SVector(rand(7)) - )) isa DI.BatchSizeSettings{5} + @test ( + @inferred DI.pick_batchsize( + AutoForwardDiff(; chunksize = 5), @SVector(rand(7)) + ) + ) isa DI.BatchSizeSettings{5} end end @@ -103,22 +105,22 @@ end x = 1.0 y = [1.0, 1.0] @test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == - ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,1} + ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1} @test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) == - Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}} # Gradient x = [1.0, 1.0] @test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) == - Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(sum),Float64},Float64,2}} + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(sum), Float64}, Float64, 2}} # Jacobian x = [1.0, 0.0, 0.0] @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == - ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,3} + ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3} @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == - Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,3}} + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 3}} @test DI.overloaded_input_type( prepare_jacobian(copyto!, similar(x), sparse_backend, x) - ) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} + ) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}} end; diff --git a/DifferentiationInterface/test/Back/GTPSA/test.jl b/DifferentiationInterface/test/Back/GTPSA/test.jl index 209f995e2..63d7e7ea5 100644 --- a/DifferentiationInterface/test/Back/GTPSA/test.jl +++ b/DifferentiationInterface/test/Back/GTPSA/test.jl @@ -18,17 +18,17 @@ end # Test no Descriptor (use context) test_differentiation( AutoGTPSA(), - default_scenarios(; include_constantified=true); - type_stability=safetypestab(:full), - logging=LOGGING, + default_scenarios(; include_constantified = true); + type_stability = safetypestab(:full), + logging = LOGGING, ); # Test with Descriptor: d1 = GTPSA.Descriptor(20, 2) # 20 variables to 2nd order -test_differentiation(AutoGTPSA(d1); type_stability=safetypestab(:full), logging=LOGGING); +test_differentiation(AutoGTPSA(d1); type_stability = safetypestab(:full), logging = LOGGING); # Test with Descriptor using varying orders vos = 2 * ones(Int, 20) vos[1] = 3 d2 = GTPSA.Descriptor(vos, 3) -test_differentiation(AutoGTPSA(d2); type_stability=safetypestab(:full), logging=LOGGING); +test_differentiation(AutoGTPSA(d2); type_stability = safetypestab(:full), logging = LOGGING); diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index b695179f1..3b61e3547 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -11,9 +11,9 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" backends = [ - AutoMooncake(; config=nothing), - AutoMooncake(; config=Mooncake.Config()), - AutoMooncakeForward(; config=nothing), + AutoMooncake(; config = nothing), + AutoMooncake(; config = Mooncake.Config()), + AutoMooncakeForward(; config = nothing), ] for backend in backends @@ -24,14 +24,14 @@ end test_differentiation( backends, default_scenarios(; - include_constantified=true, include_cachified=true, use_tuples=true + include_constantified = true, include_cachified = true, use_tuples = true ); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ); @testset "NamedTuples" begin - ps = (; A=rand(5), B=rand(5)) + ps = (; A = rand(5), B = rand(5)) myfun(ps) = sum(ps.A .* ps.B) grad = gradient(myfun, backends[1], ps) @test grad.A == ps.B diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 34b59d46a..e1482731f 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -15,8 +15,8 @@ LOGGING = get(ENV, "CI", "false") == "false" struct MyTag end backends = [ - AutoPolyesterForwardDiff(; tag=ForwardDiff.Tag(MyTag(), Float64)), # - AutoPolyesterForwardDiff(; chunksize=2), + AutoPolyesterForwardDiff(; tag = ForwardDiff.Tag(MyTag(), Float64)), # + AutoPolyesterForwardDiff(; chunksize = 2), ] for backend in backends @@ -29,14 +29,14 @@ end test_differentiation( backends, default_scenarios(; - include_constantified=true, include_cachified=true, use_tuples=true + include_constantified = true, include_cachified = true, use_tuples = true ); - logging=LOGGING, + logging = LOGGING, ); @testset "Batch size" begin @test DI.pick_batchsize(AutoPolyesterForwardDiff(), 10) == DI.pick_batchsize(AutoForwardDiff(), 10) - @test DI.pick_batchsize(AutoPolyesterForwardDiff(; chunksize=3), rand(10)) == - DI.pick_batchsize(AutoForwardDiff(; chunksize=3), rand(10)) + @test DI.pick_batchsize(AutoPolyesterForwardDiff(; chunksize = 3), rand(10)) == + DI.pick_batchsize(AutoForwardDiff(; chunksize = 3), rand(10)) end diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index c1b902d14..2043c60f2 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -13,7 +13,7 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] +backends = [AutoReverseDiff(; compile = false), AutoReverseDiff(; compile = true)] second_order_backends = [SecondOrder(AutoForwardDiff(), AutoReverseDiff())] for backend in vcat(backends, second_order_backends) @@ -25,12 +25,12 @@ end test_differentiation( vcat(backends, second_order_backends), - default_scenarios(; include_constantified=true); - logging=LOGGING, + default_scenarios(; include_constantified = true); + logging = LOGGING, ); test_differentiation( - backends, static_scenarios(; include_constantified=true); logging=LOGGING + backends, static_scenarios(; include_constantified = true); logging = LOGGING ); ## Sparse @@ -38,8 +38,8 @@ test_differentiation( test_differentiation( MyAutoSparse.(vcat(backends, second_order_backends)), sparse_scenarios(); - sparsity=true, - logging=LOGGING, + sparsity = true, + logging = LOGGING, ); @testset verbose = true "Overloaded inputs" begin @@ -48,16 +48,16 @@ test_differentiation( # Derivative x = 1.0 @test_skip DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == - ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} + ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}} # Gradient x = [1.0; 0.0; 0.0] @test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) == - ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} + ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}} # Jacobian @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == - ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} + ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}} @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == - ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} + ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}} end; diff --git a/DifferentiationInterface/test/Back/SparsityDetector/test.jl b/DifferentiationInterface/test/Back/SparsityDetector/test.jl index 1ea6a975a..f1de5c11b 100644 --- a/DifferentiationInterface/test/Back/SparsityDetector/test.jl +++ b/DifferentiationInterface/test/Back/SparsityDetector/test.jl @@ -27,10 +27,10 @@ g(x::AbstractVector) = dot(x, Hc, x) g(x::AbstractMatrix) = g(vec(x)) @testset verbose = true "$(typeof(backend))" for backend in - [AutoForwardDiff(), AutoReverseDiff()] - @test_throws ArgumentError DenseSparsityDetector(backend; atol=1e-5, method=:random) + [AutoForwardDiff(), AutoReverseDiff()] + @test_throws ArgumentError DenseSparsityDetector(backend; atol = 1.0e-5, method = :random) @testset "$method" for method in (:iterative, :direct) - detector = DenseSparsityDetector(backend; atol=1e-5, method) + detector = DenseSparsityDetector(backend; atol = 1.0e-5, method) string(detector) for (x, y) in ((rand(20), zeros(10)), (rand(2, 10), zeros(5, 2))) @test Jc == jacobian_sparsity(f, x, detector) diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl index 43bf80cb8..86d903faa 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl @@ -20,16 +20,16 @@ end test_differentiation( AutoFastDifferentiation(), default_scenarios(; - include_constantified=true, include_cachified=true, use_tuples=false + include_constantified = true, include_cachified = true, use_tuples = false ); - logging=LOGGING, + logging = LOGGING, ); test_differentiation( AutoSparse(AutoFastDifferentiation()), - sparse_scenarios(; band_sizes=0:-1); - sparsity=true, - logging=LOGGING, + sparse_scenarios(; band_sizes = 0:-1); + sparsity = true, + logging = LOGGING, ); @testset "SparseMatrixColorings access" begin diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl index 6d3008004..5fe04d8c9 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl @@ -18,20 +18,20 @@ for backend in [AutoSymbolics(), AutoSparse(AutoSymbolics())] end test_differentiation( - AutoSymbolics(), default_scenarios(; include_constantified=true); logging=LOGGING + AutoSymbolics(), default_scenarios(; include_constantified = true); logging = LOGGING ); test_differentiation( AutoSymbolics(), - default_scenarios(; include_normal=false, include_cachified=true, use_tuples=false); - logging=LOGGING, + default_scenarios(; include_normal = false, include_cachified = true, use_tuples = false); + logging = LOGGING, ); test_differentiation( AutoSparse(AutoSymbolics()), - sparse_scenarios(; band_sizes=0:-1); - sparsity=true, - logging=LOGGING, + sparse_scenarios(; band_sizes = 0:-1); + sparsity = true, + logging = LOGGING, ); @testset "SparseMatrixColorings access" begin diff --git a/DifferentiationInterface/test/Back/Tracker/test.jl b/DifferentiationInterface/test/Back/Tracker/test.jl index 0131a183e..fe54c6aaf 100644 --- a/DifferentiationInterface/test/Back/Tracker/test.jl +++ b/DifferentiationInterface/test/Back/Tracker/test.jl @@ -17,7 +17,7 @@ end test_differentiation( AutoTracker(), - default_scenarios(; include_constantified=true); - excluded=SECOND_ORDER, - logging=LOGGING, + default_scenarios(; include_constantified = true); + excluded = SECOND_ORDER, + logging = LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 882777e20..f9dc8c78f 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -28,24 +28,24 @@ end test_differentiation( backends, default_scenarios(; - include_constantified=true, include_cachified=true, use_tuples=true + include_constantified = true, include_cachified = true, use_tuples = true ); - excluded=[:second_derivative], - logging=LOGGING, + excluded = [:second_derivative], + logging = LOGGING, ) - test_differentiation(second_order_backends; logging=LOGGING) + test_differentiation(second_order_backends; logging = LOGGING) test_differentiation( backends[1], vcat(component_scenarios(), gpu_scenarios()); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ) end test_differentiation( - AutoZygote(), complex_scenarios(); excluded=[:gradient, :jacobian], logging=LOGGING + AutoZygote(), complex_scenarios(); excluded = [:gradient, :jacobian], logging = LOGGING ); ## Sparse @@ -53,8 +53,8 @@ test_differentiation( @testset "Sparse" begin test_differentiation( MyAutoSparse.(vcat(backends, second_order_backends)), - sparse_scenarios(; band_sizes=0:-1); - sparsity=true, - logging=LOGGING, + sparse_scenarios(; band_sizes = 0:-1); + sparsity = true, + logging = LOGGING, ) end diff --git a/DifferentiationInterface/test/Core/Internals/_formalities.jl b/DifferentiationInterface/test/Core/Internals/_formalities.jl index 475676702..b27028d4c 100644 --- a/DifferentiationInterface/test/Core/Internals/_formalities.jl +++ b/DifferentiationInterface/test/Core/Internals/_formalities.jl @@ -9,11 +9,11 @@ using SparseMatrixColorings using SparseArrays @testset "Aqua" begin - Aqua.test_all(DifferentiationInterface; ambiguities=false, undocumented_names=true) + Aqua.test_all(DifferentiationInterface; ambiguities = false, undocumented_names = true) end @testset "JET" begin - JET.test_package(DifferentiationInterface; target_defined_modules=true) + JET.test_package(DifferentiationInterface; target_defined_modules = true) end @testset "Documentation" begin @@ -29,7 +29,7 @@ end @test check_all_qualified_accesses_via_owners(DifferentiationInterface) === nothing @test check_no_self_qualified_accesses(DifferentiationInterface) === nothing if VERSION >= v"1.11" - @test check_all_explicit_imports_are_public(DifferentiationInterface;) === nothing + @test check_all_explicit_imports_are_public(DifferentiationInterface) === nothing @test_skip check_all_qualified_accesses_are_public(DifferentiationInterface) === nothing end diff --git a/DifferentiationInterface/test/Core/Internals/batchsize.jl b/DifferentiationInterface/test/Core/Internals/batchsize.jl index 52cc24b91..a97ddbbaa 100644 --- a/DifferentiationInterface/test/Core/Internals/batchsize.jl +++ b/DifferentiationInterface/test/Core/Internals/batchsize.jl @@ -13,9 +13,9 @@ using Test BSS = BatchSizeSettings @testset "Default" begin - @test (@inferred pick_batchsize(AutoZygote(), zeros(0))) isa BSS{1,false,true} - @test (@inferred pick_batchsize(AutoZygote(), zeros(2))) isa BSS{1,false,true} - @test (@inferred pick_batchsize(AutoZygote(), zeros(100))) isa BSS{1,false,true} + @test (@inferred pick_batchsize(AutoZygote(), zeros(0))) isa BSS{1, false, true} + @test (@inferred pick_batchsize(AutoZygote(), zeros(2))) isa BSS{1, false, true} + @test (@inferred pick_batchsize(AutoZygote(), zeros(100))) isa BSS{1, false, true} @test_throws ArgumentError pick_batchsize(AutoSparse(AutoZygote()), zeros(2)) @test_throws ArgumentError pick_batchsize( SecondOrder(AutoZygote(), AutoZygote()), zeros(2) @@ -26,75 +26,79 @@ BSS = BatchSizeSettings end @testset "SimpleFiniteDiff (adaptive)" begin - @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(0))) isa BSS{1,false,true} - @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(2))) isa BSS{2,true,true} - @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(6))) isa BSS{6,true,true} - @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(12))) isa BSS{12,true,true} - @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(24))) isa BSS{12,false,true} - @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(100))) isa BSS{12,false,false} + @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(0))) isa BSS{1, false, true} + @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(2))) isa BSS{2, true, true} + @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(6))) isa BSS{6, true, true} + @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(12))) isa BSS{12, true, true} + @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(24))) isa BSS{12, false, true} + @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(100))) isa BSS{12, false, false} @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(0)))) isa - BSS{0,true,true} + BSS{0, true, true} @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(2)))) isa - BSS{2,true,true} + BSS{2, true, true} @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(6)))) isa - BSS{6,true,true} + BSS{6, true, true} @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(100)))) isa - BSS{100,true,true} + BSS{100, true, true} end @testset "SimpleFiniteDiff (fixed)" begin - @test_throws ArgumentError pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(2)) + @test_throws ArgumentError pick_batchsize(AutoSimpleFiniteDiff(; chunksize = 4), zeros(2)) @test_throws ArgumentError pick_batchsize( - AutoSimpleFiniteDiff(; chunksize=4), @SVector(zeros(2)) + AutoSimpleFiniteDiff(; chunksize = 4), @SVector(zeros(2)) ) - @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(6)) isa BSS{4} - @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(100)) isa BSS{4} - BSS{4,true,true} - @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(99)) isa BSS{4} - BSS{4,true,false} - @test (@inferred pick_batchsize( - AutoSimpleFiniteDiff(; chunksize=4), @SVector(zeros(6)) - )) isa BSS{4,false,false} - @test (@inferred pick_batchsize( - AutoSimpleFiniteDiff(; chunksize=4), @SVector(zeros(100)) - )) isa BSS{4,false,true} + @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize = 4), zeros(6)) isa BSS{4} + @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize = 4), zeros(100)) isa BSS{4} + BSS{4, true, true} + @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize = 4), zeros(99)) isa BSS{4} + BSS{4, true, false} + @test ( + @inferred pick_batchsize( + AutoSimpleFiniteDiff(; chunksize = 4), @SVector(zeros(6)) + ) + ) isa BSS{4, false, false} + @test ( + @inferred pick_batchsize( + AutoSimpleFiniteDiff(; chunksize = 4), @SVector(zeros(100)) + ) + ) isa BSS{4, false, true} end @testset "Thresholding" begin @test threshold_batchsize(AutoSimpleFiniteDiff(), 2) isa AutoSimpleFiniteDiff{nothing} - @test threshold_batchsize(AutoSimpleFiniteDiff(; chunksize=4), 2) isa + @test threshold_batchsize(AutoSimpleFiniteDiff(; chunksize = 4), 2) isa AutoSimpleFiniteDiff{2} - @test threshold_batchsize(AutoSimpleFiniteDiff(; chunksize=4), 6) isa + @test threshold_batchsize(AutoSimpleFiniteDiff(; chunksize = 4), 6) isa AutoSimpleFiniteDiff{4} - @test threshold_batchsize(AutoSparse(AutoSimpleFiniteDiff(; chunksize=4)), 2) isa + @test threshold_batchsize(AutoSparse(AutoSimpleFiniteDiff(; chunksize = 4)), 2) isa AutoSparse{<:AutoSimpleFiniteDiff{2}} @test threshold_batchsize( SecondOrder( - AutoSimpleFiniteDiff(; chunksize=4), AutoSimpleFiniteDiff(; chunksize=3) + AutoSimpleFiniteDiff(; chunksize = 4), AutoSimpleFiniteDiff(; chunksize = 3) ), 6, - ) isa SecondOrder{<:AutoSimpleFiniteDiff{4},<:AutoSimpleFiniteDiff{3}} + ) isa SecondOrder{<:AutoSimpleFiniteDiff{4}, <:AutoSimpleFiniteDiff{3}} @test threshold_batchsize( SecondOrder( - AutoSimpleFiniteDiff(; chunksize=4), AutoSimpleFiniteDiff(; chunksize=3) + AutoSimpleFiniteDiff(; chunksize = 4), AutoSimpleFiniteDiff(; chunksize = 3) ), 2, - ) isa SecondOrder{<:AutoSimpleFiniteDiff{2},<:AutoSimpleFiniteDiff{2}} + ) isa SecondOrder{<:AutoSimpleFiniteDiff{2}, <:AutoSimpleFiniteDiff{2}} @test threshold_batchsize( SecondOrder( - AutoSimpleFiniteDiff(; chunksize=1), AutoSimpleFiniteDiff(; chunksize=3) + AutoSimpleFiniteDiff(; chunksize = 1), AutoSimpleFiniteDiff(; chunksize = 3) ), 2, - ) isa SecondOrder{<:AutoSimpleFiniteDiff{1},<:AutoSimpleFiniteDiff{2}} + ) isa SecondOrder{<:AutoSimpleFiniteDiff{1}, <:AutoSimpleFiniteDiff{2}} @test threshold_batchsize( SecondOrder( - AutoSimpleFiniteDiff(; chunksize=4), AutoSimpleFiniteDiff(; chunksize=1) + AutoSimpleFiniteDiff(; chunksize = 4), AutoSimpleFiniteDiff(; chunksize = 1) ), 2, - ) isa SecondOrder{<:AutoSimpleFiniteDiff{2},<:AutoSimpleFiniteDiff{1}} + ) isa SecondOrder{<:AutoSimpleFiniteDiff{2}, <:AutoSimpleFiniteDiff{1}} @test threshold_batchsize( - MixedMode(AutoSimpleFiniteDiff(; chunksize=4), AutoZygote()), 2 - ) isa MixedMode{<:AutoSimpleFiniteDiff{2},<:AutoZygote} + MixedMode(AutoSimpleFiniteDiff(; chunksize = 4), AutoZygote()), 2 + ) isa MixedMode{<:AutoSimpleFiniteDiff{2}, <:AutoZygote} end @testset "Reasonable" begin diff --git a/DifferentiationInterface/test/Core/Internals/context.jl b/DifferentiationInterface/test/Core/Internals/context.jl index 6ac8be0f6..e056f2908 100644 --- a/DifferentiationInterface/test/Core/Internals/context.jl +++ b/DifferentiationInterface/test/Core/Internals/context.jl @@ -17,4 +17,4 @@ r = @inferred Rewrap() contexts = (Constant(1.0), Cache([2.0])) r = @inferred Rewrap(contexts...) @test (@inferred r(3.0, [4.0])) == (Constant(3.0), Cache([4.0])) -@test (@inferred r(3, [4.0f0])) isa Tuple{Constant{Int},Cache{Vector{Float32}}} +@test (@inferred r(3, [4.0f0])) isa Tuple{Constant{Int}, Cache{Vector{Float32}}} diff --git a/DifferentiationInterface/test/Core/Internals/display.jl b/DifferentiationInterface/test/Core/Internals/display.jl index 9998d7b68..316fa6921 100644 --- a/DifferentiationInterface/test/Core/Internals/display.jl +++ b/DifferentiationInterface/test/Core/Internals/display.jl @@ -6,7 +6,7 @@ using Test backend = SecondOrder(AutoForwardDiff(), AutoZygote()) @test string(backend) == "SecondOrder(AutoForwardDiff(), AutoZygote())" -detector = DenseSparsityDetector(AutoForwardDiff(); atol=1e-23) +detector = DenseSparsityDetector(AutoForwardDiff(); atol = 1.0e-23) @test string(detector) == "DenseSparsityDetector(AutoForwardDiff(); atol=1.0e-23, method=:iterative)" diff --git a/DifferentiationInterface/test/Core/Internals/linalg.jl b/DifferentiationInterface/test/Core/Internals/linalg.jl index 0bb63d69d..bc12b1b17 100644 --- a/DifferentiationInterface/test/Core/Internals/linalg.jl +++ b/DifferentiationInterface/test/Core/Internals/linalg.jl @@ -5,8 +5,8 @@ using Test @testset "Recursive similar" begin @test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32} @test recursive_similar((ones(Int, 2), ones(Bool, 3, 4)), Float32) isa - Tuple{Vector{Float32},Matrix{Float32}} - @test recursive_similar((a=ones(Int, 2), b=(ones(Bool, 3, 4),)), Float32) isa + Tuple{Vector{Float32}, Matrix{Float32}} + @test recursive_similar((a = ones(Int, 2), b = (ones(Bool, 3, 4),)), Float32) isa @NamedTuple{a::Vector{Float32}, b::Tuple{Matrix{Float32}}} @test_throws MethodError recursive_similar(1, Float32) end diff --git a/DifferentiationInterface/test/Core/Internals/signature.jl b/DifferentiationInterface/test/Core/Internals/signature.jl index 3c326c4c7..1157d6eb2 100644 --- a/DifferentiationInterface/test/Core/Internals/signature.jl +++ b/DifferentiationInterface/test/Core/Internals/signature.jl @@ -12,7 +12,7 @@ c = 2.0 @testset "Out of place, no tangents" begin prep = prepare_derivative(f, backend, x, Constant(c)) - prep_chill = prepare_derivative(f, backend, x, Constant(c); strict=Val(false)) + prep_chill = prepare_derivative(f, backend, x, Constant(c); strict = Val(false)) @test_throws MethodError derivative(nothing, prep_chill, backend, x, Constant(c)) @@ -69,7 +69,7 @@ end @testset "In place, no tangents" begin prep = prepare_derivative(f!, y, backend, x) - prep_chill = prepare_derivative(f!, y, backend, x; strict=Val(false)) + prep_chill = prepare_derivative(f!, y, backend, x; strict = Val(false)) @test_throws MethodError derivative(nothing, y, prep_chill, backend, x, Constant(c)) @@ -87,7 +87,7 @@ end @testset "Out of place, with tangents" begin prep = prepare_pushforward(f, backend, x, (x,), Constant(c)) - prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(false)) + prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict = Val(false)) @test_throws MethodError pushforward(nothing, prep_chill, backend, x, (x,)) @@ -106,7 +106,7 @@ end @testset "In place, with tangents" begin prep = prepare_pushforward(f!, y, backend, x, (x,)) prep_chill = prepare_pushforward( - f!, y, backend, x, (x,), Constant(c); strict=Val(false) + f!, y, backend, x, (x,), Constant(c); strict = Val(false) ) @test_throws MethodError pushforward(nothing, y, prep_chill, backend, x, (x,)) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 6c5595475..8ebf4bf95 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -11,38 +11,38 @@ using Test LOGGING = get(ENV, "CI", "false") == "false" backends = [ # - AutoSimpleFiniteDiff(; chunksize=5), - AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), - AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), + AutoSimpleFiniteDiff(; chunksize = 5), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 4)), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 4)), ] second_order_backends = [ # SecondOrder( - AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize=5)), - AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 5)), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 4)), ), SecondOrder( - AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=5)), - AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 5)), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 4)), ), ] second_order_hvp_backends = [ # SecondOrder( - AutoReverseFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(); inplace = false), AutoForwardFromPrimitive(AutoSimpleFiniteDiff()), ), SecondOrder( - AutoForwardFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), - AutoReverseFromPrimitive(AutoSimpleFiniteDiff();), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(); inplace = false), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), ), SecondOrder( - AutoForwardFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), - AutoForwardFromPrimitive(AutoSimpleFiniteDiff();), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(); inplace = false), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff()), ), SecondOrder( - AutoReverseFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), - AutoReverseFromPrimitive(AutoSimpleFiniteDiff();), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(); inplace = false), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), ), ] @@ -63,25 +63,25 @@ end @testset "Dense" begin test_differentiation( vcat(backends, second_order_backends), - default_scenarios(; include_constantified=true, include_smaller=true); - logging=LOGGING, + default_scenarios(; include_constantified = true, include_smaller = true); + logging = LOGGING, ) test_differentiation( second_order_hvp_backends, - default_scenarios(; include_constantorcachified=true); - excluded=vcat(FIRST_ORDER, :hessian, :second_derivative), - logging=LOGGING, + default_scenarios(; include_constantorcachified = true); + excluded = vcat(FIRST_ORDER, :hessian, :second_derivative), + logging = LOGGING, ) - test_differentiation(backends, complex_scenarios(); logging=LOGGING) + test_differentiation(backends, complex_scenarios(); logging = LOGGING) end @testset "Sparse" begin test_differentiation( MyAutoSparse.(adaptive_backends), - default_scenarios(; include_constantified=true); - logging=LOGGING, + default_scenarios(; include_constantified = true); + logging = LOGGING, ) test_differentiation( @@ -89,13 +89,13 @@ end vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2])) ), sparse_scenarios(; - include_constantified=true, - include_cachified=true, - include_constantorcachified=true, - use_tuples=true, + include_constantified = true, + include_cachified = true, + include_constantorcachified = true, + use_tuples = true, ); - sparsity=true, - logging=LOGGING, + sparsity = true, + logging = LOGGING, ) @testset "Complex numbers" begin @@ -104,11 +104,11 @@ end vcat( adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2]) ); - sparsity_detector=DenseSparsityDetector(AutoSimpleFiniteDiff(); atol=1e-5), - coloring_algorithm=GreedyColoringAlgorithm(), + sparsity_detector = DenseSparsityDetector(AutoSimpleFiniteDiff(); atol = 1.0e-5), + coloring_algorithm = GreedyColoringAlgorithm(), ), complex_sparse_scenarios(); - logging=LOGGING, + logging = LOGGING, ) end @@ -152,6 +152,6 @@ end AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), ], vcat(static_scenarios(), gpu_scenarios()); - logging=LOGGING, + logging = LOGGING, ) end; diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 396e7d455..6b84e56a4 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -25,21 +25,21 @@ end SecondOrder(AutoZeroForward(), AutoZeroReverse()), SecondOrder(AutoZeroReverse(), AutoZeroForward()), ], - default_scenarios(; include_batchified=false, include_constantified=true); - correctness=false, - type_stability=safetypestab(:full), - logging=LOGGING, + default_scenarios(; include_batchified = false, include_constantified = true); + correctness = false, + type_stability = safetypestab(:full), + logging = LOGGING, ) test_differentiation( - AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()), - default_scenarios(; include_constantified=true); - correctness=false, - type_stability=safetypestab(:full), - excluded=[ - :pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative + AutoSparse.(zero_backends, coloring_algorithm = GreedyColoringAlgorithm()), + default_scenarios(; include_constantified = true); + correctness = false, + type_stability = safetypestab(:full), + excluded = [ + :pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative, ], - logging=LOGGING, + logging = LOGGING, ) end @@ -47,13 +47,13 @@ end test_differentiation( [AutoZeroForward(), AutoZeroReverse()], zero.(vcat(component_scenarios(), static_scenarios(), gpu_scenarios())); - correctness=true, - logging=LOGGING, + correctness = true, + logging = LOGGING, ) end @testset "Empty arrays" begin test_differentiation( - [AutoZeroForward(), AutoZeroReverse()], empty_scenarios(); excluded=[:jacobian] + [AutoZeroForward(), AutoZeroReverse()], empty_scenarios(); excluded = [:jacobian] ) end; diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index 804bad70e..da9a76b3a 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -18,9 +18,9 @@ test_differentiation( # AutoEnzyme(), # TODO a few scenarios fail ], DIT.flux_scenarios(Random.MersenneTwister(0)); - isapprox=DIT.flux_isapprox, - rtol=1e-2, - atol=1e-4, - scenario_intact=false, # TODO: why? - logging=LOGGING, + isapprox = DIT.flux_isapprox, + rtol = 1.0e-2, + atol = 1.0e-4, + scenario_intact = false, # TODO: why? + logging = LOGGING, ) diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl index 732665148..b1d0652ff 100644 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -14,9 +14,9 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( AutoZygote(), DIT.lux_scenarios(Random.Xoshiro(63)); - isapprox=DIT.lux_isapprox, - rtol=1.0f-2, - atol=1.0f-3, - scenario_intact=false, # TODO: why? - logging=LOGGING, + isapprox = DIT.lux_isapprox, + rtol = 1.0f-2, + atol = 1.0f-3, + scenario_intact = false, # TODO: why? + logging = LOGGING, ) diff --git a/DifferentiationInterface/test/GPU/CUDA/main.jl b/DifferentiationInterface/test/GPU/CUDA/main.jl index f7d1fabfb..ec16df87c 100644 --- a/DifferentiationInterface/test/GPU/CUDA/main.jl +++ b/DifferentiationInterface/test/GPU/CUDA/main.jl @@ -1,7 +1,7 @@ @info "Testing on CUDA" using Pkg Pkg.add("CUDA") -Pkg.develop(PackageSpec(; path="./DifferentiationInterface")) +Pkg.develop(PackageSpec(; path = "./DifferentiationInterface")) using Test @testset verbose = true "Simple" begin diff --git a/DifferentiationInterface/test/runtests.jl b/DifferentiationInterface/test/runtests.jl index e50d4839d..df7884a27 100644 --- a/DifferentiationInterface/test/runtests.jl +++ b/DifferentiationInterface/test/runtests.jl @@ -4,7 +4,7 @@ using Test DIT_PATH = joinpath(@__DIR__, "..", "..", "DifferentiationInterfaceTest") if isdir(DIT_PATH) - Pkg.develop(; path=DIT_PATH) + Pkg.develop(; path = DIT_PATH) else Pkg.add("DifferentiationInterfaceTest") end @@ -19,8 +19,8 @@ include("testutils.jl") @testset verbose = true "$category" begin @testset verbose = true "$folder" begin @testset verbose = true "$file" for file in readdir( - joinpath(@__DIR__, category, folder) - ) + joinpath(@__DIR__, category, folder) + ) endswith(file, ".jl") || continue @info "Testing $category/$folder/$file" include(joinpath(@__DIR__, category, folder, file)) @@ -33,8 +33,8 @@ include("testutils.jl") @testset verbose = true for folder in readdir(joinpath(@__DIR__, category)) isdir(joinpath(@__DIR__, category, folder)) || continue @testset verbose = true "$file" for file in readdir( - joinpath(@__DIR__, category, folder) - ) + joinpath(@__DIR__, category, folder) + ) endswith(file, ".jl") || continue @info "Testing $category/$folder/$file" include(joinpath(@__DIR__, category, folder, file)) diff --git a/DifferentiationInterface/test/testutils.jl b/DifferentiationInterface/test/testutils.jl index 8dec57bfa..5bb8bd3a8 100644 --- a/DifferentiationInterface/test/testutils.jl +++ b/DifferentiationInterface/test/testutils.jl @@ -16,8 +16,8 @@ using DifferentiationInterfaceTest: function MyAutoSparse(backend::AbstractADType) return AutoSparse( backend; - sparsity_detector=TracerSparsityDetector(), - coloring_algorithm=GreedyColoringAlgorithm(; postprocessing=true), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm(; postprocessing = true), ) end diff --git a/DifferentiationInterfaceTest/README.md b/DifferentiationInterfaceTest/README.md index 376e6d93d..86c8e12bb 100644 --- a/DifferentiationInterfaceTest/README.md +++ b/DifferentiationInterfaceTest/README.md @@ -2,13 +2,13 @@ [![Build Status](https://github.com/JuliaDiff/DifferentiationInterface.jl/actions/workflows/Test.yml/badge.svg?branch=main)](https://github.com/JuliaDiff/DifferentiationInterface.jl/actions/workflows/Test.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/JuliaDiff/DifferentiationInterface.jl/branch/main/graph/badge.svg?flag=DIT)](https://app.codecov.io/gh/JuliaDiff/DifferentiationInterface.jl) -[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/JuliaDiff/BlueStyle) +[![code style: runic](https://img.shields.io/badge/code_style-%E1%9A%B1%E1%9A%A2%E1%9A%BE%E1%9B%81%E1%9A%B2-black)](https://github.com/fredrikekre/Runic.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![DOI](https://zenodo.org/badge/740973714.svg)](https://zenodo.org/doi/10.5281/zenodo.11092033) -| Package | Docs | -|:----------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| -| DifferentiationInterface | [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/dev/) | +| Package | Docs | +| :--------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| DifferentiationInterface | [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/dev/) | | DifferentiationInterfaceTest | [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterfaceTest/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/) | Testing and benchmarking utilities for automatic differentiation (AD) in Julia, based on [DifferentiationInterface](https://github.com/JuliaDiff/DifferentiationInterface.jl/tree/main/DifferentiationInterface). diff --git a/DifferentiationInterfaceTest/docs/make.jl b/DifferentiationInterfaceTest/docs/make.jl index 99468652c..50d36a595 100644 --- a/DifferentiationInterfaceTest/docs/make.jl +++ b/DifferentiationInterfaceTest/docs/make.jl @@ -7,14 +7,14 @@ using BenchmarkTools: BenchmarkTools using DataFrames: DataFrames using ForwardDiff: ForwardDiff -cp(joinpath(@__DIR__, "..", "README.md"), joinpath(@__DIR__, "src", "index.md"); force=true) +cp(joinpath(@__DIR__, "..", "README.md"), joinpath(@__DIR__, "src", "index.md"); force = true) makedocs(; - modules=[DifferentiationInterfaceTest], - authors="Guillaume Dalle, Adrian Hill", - sitename="DifferentiationInterfaceTest.jl", - format=Documenter.HTML(), - pages=[ + modules = [DifferentiationInterfaceTest], + authors = "Guillaume Dalle, Adrian Hill", + sitename = "DifferentiationInterfaceTest.jl", + format = Documenter.HTML(), + pages = [ "Home" => "index.md", # "Tutorial" => "tutorial.md", # "API reference" => "api.md", @@ -22,9 +22,9 @@ makedocs(; ) deploydocs(; - repo="github.com/JuliaDiff/DifferentiationInterface.jl", - devbranch="main", - dirname="DifferentiationInterfaceTest", - tag_prefix="DifferentiationInterfaceTest-", - push_preview=false, + repo = "github.com/JuliaDiff/DifferentiationInterface.jl", + devbranch = "main", + dirname = "DifferentiationInterfaceTest", + tag_prefix = "DifferentiationInterfaceTest-", + push_preview = false, ) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl index 2b432ad3a..15ccad4d7 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl @@ -10,7 +10,7 @@ function comp_to_num(x::ComponentVector)::Number return sum(sin.(x.a)) + sum(cos.(x.b)) end -comp_to_num_gradient(x) = ComponentVector(; a=cos.(x.a), b=(-sin.(x.b))) +comp_to_num_gradient(x) = ComponentVector(; a = cos.(x.a), b = (-sin.(x.b))) function comp_to_num_pushforward(x, dx) g = comp_to_num_gradient(x) @@ -33,13 +33,13 @@ function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy append!( scens, [ - DIT.Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)), - DIT.Scenario{:gradient,pl_op}(f, x; res1=grad), + DIT.Scenario{:pullback, pl_op}(f, x, (dy,); res1 = (dx_from_dy,)), + DIT.Scenario{:gradient, pl_op}(f, x; res1 = grad), ], ) end for pl_op in (:out,) - append!(scens, [DIT.Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))]) + append!(scens, [DIT.Scenario{:pushforward, pl_op}(f, x, (dx,); res1 = (dy_from_dx,))]) end return scens end @@ -49,12 +49,12 @@ end function DIT.component_scenarios() dy_ = -1 / 12 - x_comp = ComponentVector(; a=float.(1:4), b=float.(5:6)) - dx_comp = ComponentVector(; a=float.(4:-1:1), b=float.(6:-1:5)) + x_comp = ComponentVector(; a = float.(1:4), b = float.(5:6)) + dx_comp = ComponentVector(; a = float.(4:-1:1), b = float.(6:-1:5)) scens = vcat( # one argument - comp_to_num_scenarios_onearg(x_comp::ComponentVector; dx=dx_comp, dy=dy_), + comp_to_num_scenarios_onearg(x_comp::ComponentVector; dx = dx_comp, dy = dy_), # two arguments ) return scens diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index 88e3ef490..3296995ec 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -66,7 +66,7 @@ function square_loss_iterated(cell, x) return mean(abs2, y) end -struct SimpleDense{W,B,F} +struct SimpleDense{W, B, F} w::W b::B σ::F @@ -76,7 +76,7 @@ end @functor SimpleDense -function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) +function DIT.flux_scenarios(rng::AbstractRNG = default_rng()) init = glorot_uniform(rng) scens = DIT.Scenario[] @@ -91,12 +91,12 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) x = randn(rng, d_in) g = gradient_finite_differences(square_loss, model, x) - scen = DIT.Scenario{:gradient,:out}( + scen = DIT.Scenario{:gradient, :out}( square_loss, model, DI.Constant(x); - prep_args=(x=model, contexts=(DI.Constant(x),)), - res1=g, + prep_args = (x = model, contexts = (DI.Constant(x),)), + res1 = g, ) push!(scens, scen) @@ -160,18 +160,18 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)), randn(rng, Float32, 3, 2, 1) ), - #! format: on + #! format: on ] for (model, x) in models_and_xs Flux.trainmode!(model) g = gradient_finite_differences(square_loss, model, x) - scen = DIT.Scenario{:gradient,:out}( + scen = DIT.Scenario{:gradient, :out}( square_loss, model, DI.Constant(x); - prep_args=(; x=model, contexts=(DI.Constant(x),)), - res1=g, + prep_args = (; x = model, contexts = (DI.Constant(x),)), + res1 = g, ) push!(scens, scen) end @@ -198,12 +198,12 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) for (model, x) in recurrent_models_and_xs Flux.trainmode!(model) g = gradient_finite_differences(square_loss_iterated, model, x) - scen = DIT.Scenario{:gradient,:out}( + scen = DIT.Scenario{:gradient, :out}( square_loss_iterated, model, DI.Constant(x); - prep_args=(; x=model, contexts=(DI.Constant(x),)), - res1=g, + prep_args = (; x = model, contexts = (DI.Constant(x),)), + res1 = g, ) push!(scens, scen) end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index c2b3755fc..f462895e4 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -20,20 +20,20 @@ myjl(x::AbstractArray) = jl(x) myjl(x::Tuple) = map(myjl, x) myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x))) myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x))) -myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap(x))) +myjl(x::DI.Cache{<:Union{Tuple, NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap(x))) myjl(::Nothing) = nothing -function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function myjl(scen::DIT.Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} (; f, x, y, t, contexts, prep_args, res1, res2, name) = scen - return DIT.Scenario{op,pl_op,pl_fun}(; - f=myjl(f), - x=myjl(x), - y=myjl(y), - t=myjl(t), - contexts=myjl(contexts), - prep_args=map(myjl, prep_args), - res1=myjl(res1), - res2=myjl(res2), + return DIT.Scenario{op, pl_op, pl_fun}(; + f = myjl(f), + x = myjl(x), + y = myjl(y), + t = myjl(t), + contexts = myjl(contexts), + prep_args = map(myjl, prep_args), + res1 = myjl(res1), + res2 = myjl(res2), name, ) end @@ -44,7 +44,7 @@ function DIT.gpu_scenarios(args...; kwargs...) end @compile_workload begin - DIT.gpu_scenarios(; include_constantified=true, include_cachified=true) + DIT.gpu_scenarios(; include_constantified = true, include_cachified = true) end end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index 8119b073f..4dddb944d 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -44,7 +44,7 @@ function square_loss(ps, model, x, st) return sum(abs2, first(model(x, ps, st))) end -function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) +function DIT.lux_scenarios(rng::AbstractRNG = default_rng()) models_and_xs = [ #! format: off ( @@ -197,17 +197,17 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) g = DI.gradient( ps -> square_loss(ps, model, x, st), DI.AutoForwardDiff(), ComponentArray(ps) ) - scen = DIT.Scenario{:gradient,:out}( + scen = DIT.Scenario{:gradient, :out}( square_loss, ComponentArray(ps), DI.Constant(model), DI.Constant(x), DI.Constant(st); - prep_args=( - x=ComponentArray(ps), - contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)), + prep_args = ( + x = ComponentArray(ps), + contexts = (DI.Constant(model), DI.Constant(x), DI.Constant(st)), ), - res1=g, + res1 = g, ) push!(scens, scen) end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index c1f28c8ec..f5737f005 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -18,42 +18,42 @@ mystatic(f::DIT.FunctionModifier) = f mystatic(x::Number) = x -mystatic(x::AbstractVector{T}) where {T} = convert(SVector{length(x),T}, x) -mymutablestatic(x::AbstractVector{T}) where {T} = convert(MVector{length(x),T}, x) +mystatic(x::AbstractVector{T}) where {T} = convert(SVector{length(x), T}, x) +mymutablestatic(x::AbstractVector{T}) where {T} = convert(MVector{length(x), T}, x) function mystatic(x::AbstractMatrix{T}) where {T} - return convert(SMatrix{size(x, 1),size(x, 2),T,length(x)}, x) + return convert(SMatrix{size(x, 1), size(x, 2), T, length(x)}, x) end function mymutablestatic(x::AbstractMatrix{T}) where {T} - return convert(MMatrix{size(x, 1),size(x, 2),T,length(x)}, x) + return convert(MMatrix{size(x, 1), size(x, 2), T, length(x)}, x) end mystatic(x::Tuple) = map(mystatic, x) mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x))) mystatic(x::DI.Cache{<:AbstractArray}) = DI.Cache(mymutablestatic(DI.unwrap(x))) -function mystatic(x::DI.Cache{<:Union{Tuple,NamedTuple}}) +function mystatic(x::DI.Cache{<:Union{Tuple, NamedTuple}}) return map(mystatic, map(DI.Cache, DI.unwrap(x))) end mystatic(::Nothing) = nothing -function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function mystatic(scen::DIT.Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} (; f, x, y, t, contexts, prep_args, res1, res2, name) = scen new_prep_args = (; - x=mystatic(prep_args.x), contexts=map(mystatic, prep_args.contexts), t=mystatic(t) + x = mystatic(prep_args.x), contexts = map(mystatic, prep_args.contexts), t = mystatic(t), ) if pl_fun == :in - new_prep_args = (; new_prep_args..., y=mymutablestatic(prep_args.y)) + new_prep_args = (; new_prep_args..., y = mymutablestatic(prep_args.y)) end - return DIT.Scenario{op,pl_op,pl_fun}(; - f=mystatic(f), - x=mystatic(x), - y=pl_fun == :in ? mymutablestatic(y) : mystatic(y), - t=mystatic(t), - contexts=mystatic(contexts), - prep_args=new_prep_args, - res1=mystatic(res1), - res2=mystatic(res2), - name=name, + return DIT.Scenario{op, pl_op, pl_fun}(; + f = mystatic(f), + x = mystatic(x), + y = pl_fun == :in ? mymutablestatic(y) : mystatic(y), + t = mystatic(t), + contexts = mystatic(contexts), + prep_args = new_prep_args, + res1 = mystatic(res1), + res2 = mystatic(res2), + name = name, ) end @@ -63,7 +63,7 @@ function DIT.static_scenarios(args...; kwargs...) end @compile_workload begin - DIT.static_scenarios(; include_constantified=true, include_cachified=true) + DIT.static_scenarios(; include_constantified = true, include_cachified = true) end end diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index bfe491ef9..51b216915 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -152,7 +152,7 @@ export test_differentiation, benchmark_differentiation export DifferentiationBenchmarkDataRow @compile_workload begin - default_scenarios(; include_constantified=true, include_cachified=true) + default_scenarios(; include_constantified = true, include_cachified = true) end end diff --git a/DifferentiationInterfaceTest/src/scenarios/allocfree.jl b/DifferentiationInterfaceTest/src/scenarios/allocfree.jl index d6f089195..7abd1b90a 100644 --- a/DifferentiationInterfaceTest/src/scenarios/allocfree.jl +++ b/DifferentiationInterfaceTest/src/scenarios/allocfree.jl @@ -5,9 +5,9 @@ function identity_scenarios(x::Number; dx::Number, dy::Number) der = oneunit(x) return [ - Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), - Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)), - Scenario{:derivative,:out}(f, x; res1=der), + Scenario{:pushforward, :out}(f, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:pullback, :out}(f, x, (dy,); res1 = (dx_from_dy,)), + Scenario{:derivative, :out}(f, x; res1 = der), ] end @@ -19,9 +19,9 @@ function sum_scenarios(x::AbstractArray; dx::AbstractArray, dy::Number) grad .= oneunit(eltype(x)) return [ - Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), - Scenario{:pullback,:in}(f, x, (dy,); res1=(dx_from_dy,)), - Scenario{:gradient,:in}(f, x; res1=grad), + Scenario{:pushforward, :out}(f, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:pullback, :in}(f, x, (dy,); res1 = (dx_from_dy,)), + Scenario{:gradient, :in}(f, x; res1 = grad), ] end @@ -34,9 +34,9 @@ function copyto!_scenarios(x::AbstractArray; dx::AbstractArray, dy::AbstractArra jac = Matrix(Diagonal(ones(eltype(x), length(x)))) return [ - Scenario{:pushforward,:in}(f!, y, x, (dx,); res1=(dy_from_dx,)), - Scenario{:pullback,:in}(f!, y, x, (dy,); res1=(dx_from_dy,)), - Scenario{:jacobian,:in}(f!, y, x; res1=jac), + Scenario{:pushforward, :in}(f!, y, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:pullback, :in}(f!, y, x, (dy,); res1 = (dx_from_dy,)), + Scenario{:jacobian, :in}(f!, y, x; res1 = jac), ] end @@ -59,9 +59,9 @@ function allocfree_scenarios() dy_6 = float.(-5:2:5) scens = vcat( - identity_scenarios(x_; dx=dx_, dy=dy_), # - sum_scenarios(x_6; dx=dx_6, dy=dy_), - copyto!_scenarios(x_6; dx=dx_6, dy=dy_6), + identity_scenarios(x_; dx = dx_, dy = dy_), # + sum_scenarios(x_6; dx = dx_6, dy = dy_), + copyto!_scenarios(x_6; dx = dx_6, dy = dy_6), ) return scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/complex.jl b/DifferentiationInterfaceTest/src/scenarios/complex.jl index 39980d521..5437e0fec 100644 --- a/DifferentiationInterfaceTest/src/scenarios/complex.jl +++ b/DifferentiationInterfaceTest/src/scenarios/complex.jl @@ -7,10 +7,10 @@ function complex_holomorphic_gradient_scenarios() x = [1.0 + im] grad = 2 * conj(x) scens = Scenario[ - Scenario{:gradient,:out}(square_only, x; res1=grad), - Scenario{:gradient,:in}(square_only, x; res1=grad), - Scenario{:pullback,:out}(square_only, x, (dy,); res1=(grad,)), - Scenario{:pullback,:in}(square_only, x, (dy,); res1=(grad,)), + Scenario{:gradient, :out}(square_only, x; res1 = grad), + Scenario{:gradient, :in}(square_only, x; res1 = grad), + Scenario{:pullback, :out}(square_only, x, (dy,); res1 = (grad,)), + Scenario{:pullback, :in}(square_only, x, (dy,); res1 = (grad,)), ] return scens end @@ -20,10 +20,10 @@ function complex_gradient_scenarios() x = [1.0 + im] grad = 2 * x scens = Scenario[ - Scenario{:gradient,:out}(abs2_only, x; res1=grad), - Scenario{:gradient,:in}(abs2_only, x; res1=grad), - Scenario{:pullback,:out}(abs2_only, x, (dy,); res1=(grad,)), - Scenario{:pullback,:in}(abs2_only, x, (dy,); res1=(grad,)), + Scenario{:gradient, :out}(abs2_only, x; res1 = grad), + Scenario{:gradient, :in}(abs2_only, x; res1 = grad), + Scenario{:pullback, :out}(abs2_only, x, (dy,); res1 = (grad,)), + Scenario{:pullback, :in}(abs2_only, x, (dy,); res1 = (grad,)), ] return scens end @@ -47,13 +47,13 @@ function complex_scenarios() scens = vcat( # one argument - num_to_num_scenarios(x_; dx=dx_, dy=dy_), - num_to_vec_scenarios_onearg(x_; dx=dx_, dy=dy_2), - arr_to_num_scenarios_onearg(x_6; dx=dx_6, dy=dy_), - vec_to_vec_scenarios_onearg(x_6; dx=dx_6, dy=dy_12), + num_to_num_scenarios(x_; dx = dx_, dy = dy_), + num_to_vec_scenarios_onearg(x_; dx = dx_, dy = dy_2), + arr_to_num_scenarios_onearg(x_6; dx = dx_6, dy = dy_), + vec_to_vec_scenarios_onearg(x_6; dx = dx_6, dy = dy_12), # two arguments - num_to_vec_scenarios_twoarg(x_; dx=dx_, dy=dy_6), - vec_to_vec_scenarios_twoarg(x_6; dx=dx_6, dy=dy_12), + num_to_vec_scenarios_twoarg(x_; dx = dx_, dy = dy_6), + vec_to_vec_scenarios_twoarg(x_6; dx = dx_6, dy = dy_12), # complex gradients complex_gradient_scenarios(), complex_holomorphic_gradient_scenarios(), diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index a5889a379..c8229c1c4 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -27,10 +27,10 @@ function num_to_num_scenarios(x::Number; dx::Number, dy::Number) # everyone out of place scens = Scenario[ - Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), - Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)), - Scenario{:derivative,:out}(f, x; res1=der), - Scenario{:second_derivative,:out}(f, x; res1=der, res2=der2), + Scenario{:pushforward, :out}(f, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:pullback, :out}(f, x, (dy,); res1 = (dx_from_dy,)), + Scenario{:derivative, :out}(f, x; res1 = der), + Scenario{:second_derivative, :out}(f, x; res1 = der, res2 = der2), ] return scens end @@ -56,13 +56,13 @@ function onevec_to_onevec_scenarios_onearg(x::Number; dx::Number, dy::Number) append!( scens, [ - Scenario{:pushforward,pl_op}( - onevec_to_onevec, [x], ([dx],); res1=([dy_from_dx],) + Scenario{:pushforward, pl_op}( + onevec_to_onevec, [x], ([dx],); res1 = ([dy_from_dx],) ), - Scenario{:pullback,pl_op}( - onevec_to_onevec, [x], ([dy],); res1=([dx_from_dy],) + Scenario{:pullback, pl_op}( + onevec_to_onevec, [x], ([dy],); res1 = ([dx_from_dy],) ), - Scenario{:jacobian,pl_op}(onevec_to_onevec, [x]; res1=jac), + Scenario{:jacobian, pl_op}(onevec_to_onevec, [x]; res1 = jac), ], ) end @@ -84,13 +84,13 @@ function onevec_to_onevec_scenarios_twoarg(x::Number; dx::Number, dy::Number) append!( scens, [ - Scenario{:pushforward,pl_op}( - onevec_to_onevec!, [y], [x], ([dx],); res1=([dy_from_dx],) + Scenario{:pushforward, pl_op}( + onevec_to_onevec!, [y], [x], ([dx],); res1 = ([dy_from_dx],) ), - Scenario{:pullback,pl_op}( - onevec_to_onevec!, [y], [x], ([dy],); res1=([dx_from_dy],) + Scenario{:pullback, pl_op}( + onevec_to_onevec!, [y], [x], ([dy],); res1 = ([dx_from_dy],) ), - Scenario{:jacobian,pl_op}(onevec_to_onevec!, [y], [x]; res1=jac), + Scenario{:jacobian, pl_op}(onevec_to_onevec!, [y], [x]; res1 = jac), ], ) end @@ -121,7 +121,7 @@ end num_to_vec!_pushforward(x, dx; y) = dx .* num_to_vec!_derivative(x; y) function num_to_vec!_pullback(x, dy) - return sum(conj(num_to_vec!_derivative(x; y=similar(dy))) .* dy) + return sum(conj(num_to_vec!_derivative(x; y = similar(dy))) .* dy) end function num_to_vec_scenarios_onearg(x::Number; dx::Number, dy::AbstractArray) @@ -137,14 +137,14 @@ function num_to_vec_scenarios_onearg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)), - Scenario{:derivative,pl_op}(f, x; res1=der), - Scenario{:second_derivative,pl_op}(f, x; res1=der, res2=der2), + Scenario{:pushforward, pl_op}(f, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:derivative, pl_op}(f, x; res1 = der), + Scenario{:second_derivative, pl_op}(f, x; res1 = der, res2 = der2), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,))]) + append!(scens, [Scenario{:pullback, pl_op}(f, x, (dy,); res1 = (dx_from_dy,))]) end return scens end @@ -163,13 +163,13 @@ function num_to_vec_scenarios_twoarg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)), - Scenario{:derivative,pl_op}(f!, y, x; res1=der), + Scenario{:pushforward, pl_op}(f!, y, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:derivative, pl_op}(f!, y, x; res1 = der), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,))]) + append!(scens, [Scenario{:pullback, pl_op}(f!, y, x, (dy,); res1 = (dx_from_dy,))]) end return scens end @@ -197,13 +197,13 @@ end function num_to_mat!_derivative(x; y) return hcat( - num_to_vec!_derivative(x; y=y[:, 1]), 3 .* num_to_vec!_derivative(3x; y=y[:, 2]) + num_to_vec!_derivative(x; y = y[:, 1]), 3 .* num_to_vec!_derivative(3x; y = y[:, 2]) ) end function num_to_mat!_pushforward(x, dx; y) return hcat( - num_to_vec!_pushforward(x, dx; y=y[:, 1]), - 3 .* num_to_vec!_pushforward(3x, dx; y=y[:, 2]), + num_to_vec!_pushforward(x, dx; y = y[:, 1]), + 3 .* num_to_vec!_pushforward(3x, dx; y = y[:, 2]), ) end function num_to_mat!_pullback(x, dy) @@ -223,14 +223,14 @@ function num_to_mat_scenarios_onearg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)), - Scenario{:derivative,pl_op}(f, x; res1=der), - Scenario{:second_derivative,pl_op}(f, x; res1=der, res2=der2), + Scenario{:pushforward, pl_op}(f, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:derivative, pl_op}(f, x; res1 = der), + Scenario{:second_derivative, pl_op}(f, x; res1 = der, res2 = der2), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,))]) + append!(scens, [Scenario{:pullback, pl_op}(f, x, (dy,); res1 = (dx_from_dy,))]) end return scens end @@ -249,13 +249,13 @@ function num_to_mat_scenarios_twoarg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)), - Scenario{:derivative,pl_op}(f!, y, x; res1=der), + Scenario{:pushforward, pl_op}(f!, y, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:derivative, pl_op}(f!, y, x; res1 = der), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,))]) + append!(scens, [Scenario{:pullback, pl_op}(f!, y, x, (dy,); res1 = (dx_from_dy,))]) end return scens end @@ -282,8 +282,8 @@ function arr_to_num_gradient(x0) for k in eachindex(g, x) g[k] = ( α * x[k]^(α - 1) * sum(x[j]^β for j in eachindex(x) if j != k) + - β * x[k]^(β - 1) * sum(x[i]^α for i in eachindex(x) if i != k) + - (α + β) * x[k]^(α + β - 1) + β * x[k]^(β - 1) * sum(x[i]^α for i in eachindex(x) if i != k) + + (α + β) * x[k]^(α + β - 1) ) end return conj(convert(typeof(x0), g)) @@ -296,8 +296,8 @@ function arr_to_num_hessian(x0) if k == l H[k, k] = ( α * (α - 1) * x[k]^(α - 2) * sum(x[j]^β for j in eachindex(x) if j != k) + - β * (β - 1) * x[k]^(β - 2) * sum(x[i]^α for i in eachindex(x) if i != k) + - (α + β) * (α + β - 1) * x[k]^(α + β - 2) + β * (β - 1) * x[k]^(β - 2) * sum(x[i]^α for i in eachindex(x) if i != k) + + (α + β) * (α + β - 1) * x[k]^(α + β - 2) ) else H[k, l] = α * β * (x[k]^(α - 1) * x[l]^(β - 1) + x[k]^(β - 1) * x[l]^(α - 1)) @@ -311,8 +311,8 @@ arr_to_num_pullback(x, dy) = arr_to_num_gradient(x) .* dy arr_to_num_hvp(x, dx) = reshape(arr_to_num_hessian(x) * vec(dx), size(x)) function arr_to_num_scenarios_onearg( - x::AbstractArray; dx::AbstractArray, dy::Number, linalg=true -) + x::AbstractArray; dx::AbstractArray, dy::Number, linalg = true + ) f = linalg ? arr_to_num_linalg : arr_to_num_no_linalg dy_from_dx = arr_to_num_pushforward(x, dx) dx_from_dy = arr_to_num_pullback(x, dy) @@ -326,15 +326,15 @@ function arr_to_num_scenarios_onearg( append!( scens, [ - Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)), - Scenario{:gradient,pl_op}(f, x; res1=grad), - Scenario{:hvp,pl_op}(f, x, (dx,); res1=grad, res2=(dg,)), - Scenario{:hessian,pl_op}(f, x; res1=grad, res2=hess), + Scenario{:pullback, pl_op}(f, x, (dy,); res1 = (dx_from_dy,)), + Scenario{:gradient, pl_op}(f, x; res1 = grad), + Scenario{:hvp, pl_op}(f, x, (dx,); res1 = grad, res2 = (dg,)), + Scenario{:hessian, pl_op}(f, x; res1 = grad, res2 = hess), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))]) + append!(scens, [Scenario{:pushforward, pl_op}(f, x, (dx,); res1 = (dy_from_dx,))]) end return scens end @@ -347,9 +347,9 @@ function all_array_to_array_scenarios(f, x; dx, dy, dy_from_dx, dx_from_dy, jac) append!( scens, [ - Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)), - Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)), - Scenario{:jacobian,pl_op}(f, x; res1=jac), + Scenario{:pushforward, pl_op}(f, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:pullback, pl_op}(f, x, (dy,); res1 = (dx_from_dy,)), + Scenario{:jacobian, pl_op}(f, x; res1 = jac), ], ) end @@ -362,9 +362,9 @@ function all_array_to_array_scenarios(f!, y, x; dx, dy, dy_from_dx, dx_from_dy, append!( scens, [ - Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)), - Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,)), - Scenario{:jacobian,pl_op}(f!, y, x; res1=jac), + Scenario{:pushforward, pl_op}(f!, y, x, (dx,); res1 = (dy_from_dx,)), + Scenario{:pullback, pl_op}(f!, y, x, (dy,); res1 = (dx_from_dy,)), + Scenario{:jacobian, pl_op}(f!, y, x; res1 = jac), ], ) end @@ -388,8 +388,8 @@ end vec_to_vec_jacobian(x) = vcat(Diagonal(cos.(x)), Diagonal(-sin.(x))) function vec_to_vec_scenarios_onearg( - x::AbstractVector; dx::AbstractVector, dy::AbstractVector -) + x::AbstractVector; dx::AbstractVector, dy::AbstractVector + ) f = vec_to_vec dy_from_dx = vec_to_vec_pushforward(x, dx) dx_from_dy = vec_to_vec_pullback(x, dy) @@ -399,8 +399,8 @@ function vec_to_vec_scenarios_onearg( end function vec_to_vec_scenarios_twoarg( - x::AbstractVector; dx::AbstractVector, dy::AbstractVector -) + x::AbstractVector; dx::AbstractVector, dy::AbstractVector + ) f! = vec_to_vec! y = similar(vec_to_vec(x)) f!(y, x) @@ -426,8 +426,8 @@ vec_to_mat_pullback(x, dy) = conj(cos.(x)) .* dy[:, 1] .- conj(sin.(x)) .* dy[:, vec_to_mat_jacobian(x) = vcat(Diagonal(cos.(x)), Diagonal(-sin.(x))) function vec_to_mat_scenarios_onearg( - x::AbstractVector; dx::AbstractVector, dy::AbstractMatrix -) + x::AbstractVector; dx::AbstractVector, dy::AbstractMatrix + ) f = vec_to_mat dy_from_dx = vec_to_mat_pushforward(x, dx) dx_from_dy = vec_to_mat_pullback(x, dy) @@ -437,8 +437,8 @@ function vec_to_mat_scenarios_onearg( end function vec_to_mat_scenarios_twoarg( - x::AbstractVector; dx::AbstractVector, dy::AbstractMatrix -) + x::AbstractVector; dx::AbstractVector, dy::AbstractMatrix + ) f! = vec_to_mat! y = similar(vec_to_mat(x)) f!(y, x) @@ -466,14 +466,14 @@ end function mat_to_vec_pullback(x, dy) return conj(cos.(x)) .* reshape(first_half(dy), size(x)) .- - conj(sin.(x)) .* reshape(second_half(dy), size(x)) + conj(sin.(x)) .* reshape(second_half(dy), size(x)) end mat_to_vec_jacobian(x) = vcat(Diagonal(vec(cos.(x))), Diagonal(vec(-sin.(x)))) function mat_to_vec_scenarios_onearg( - x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractVector -) + x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractVector + ) f = mat_to_vec dy_from_dx = mat_to_vec_pushforward(x, dx) dx_from_dy = mat_to_vec_pullback(x, dy) @@ -483,8 +483,8 @@ function mat_to_vec_scenarios_onearg( end function mat_to_vec_scenarios_twoarg( - x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractVector -) + x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractVector + ) f! = mat_to_vec! y = similar(mat_to_vec(x)) f!(y, x) @@ -512,14 +512,14 @@ end function mat_to_mat_pullback(x, dy) return conj(cos.(x)) .* reshape(dy[:, 1], size(x)) .- - conj(sin.(x)) .* reshape(dy[:, 2], size(x)) + conj(sin.(x)) .* reshape(dy[:, 2], size(x)) end mat_to_mat_jacobian(x) = vcat(Diagonal(vec(cos.(x))), Diagonal(vec(-sin.(x)))) function mat_to_mat_scenarios_onearg( - x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractMatrix -) + x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractMatrix + ) f = mat_to_mat dy_from_dx = mat_to_mat_pushforward(x, dx) dx_from_dy = mat_to_mat_pullback(x, dy) @@ -529,8 +529,8 @@ function mat_to_mat_scenarios_onearg( end function mat_to_mat_scenarios_twoarg( - x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractMatrix -) + x::AbstractMatrix; dx::AbstractMatrix, dy::AbstractMatrix + ) f! = mat_to_mat! y = similar(mat_to_mat(x)) f!(y, x) @@ -549,16 +549,16 @@ end Create a vector of [`Scenario`](@ref)s with standard array types. """ function default_scenarios(; - linalg=true, - include_normal=true, - include_batchified=true, - include_closurified=false, - include_constantified=false, - include_cachified=false, - include_constantorcachified=false, - use_tuples=false, - include_smaller=false, -) + linalg = true, + include_normal = true, + include_batchified = true, + include_closurified = false, + include_constantified = false, + include_cachified = false, + include_constantorcachified = false, + use_tuples = false, + include_smaller = false, + ) x_ = 0.42 dx_ = 3.14 dy_ = -1 / 12 @@ -578,63 +578,63 @@ function default_scenarios(; scens = vcat( # one argument - num_to_num_scenarios(x_; dx=dx_, dy=dy_), - onevec_to_onevec_scenarios_onearg(x_; dx=dx_, dy=dy_), - num_to_vec_scenarios_onearg(x_; dx=dx_, dy=dy_2), - num_to_mat_scenarios_onearg(x_; dx=dx_, dy=dy_2_2), - arr_to_num_scenarios_onearg(x_6; dx=dx_6, dy=dy_, linalg), - arr_to_num_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_, linalg), - vec_to_vec_scenarios_onearg(x_6; dx=dx_6, dy=dy_12), - vec_to_mat_scenarios_onearg(x_6; dx=dx_6, dy=dy_6_2), - mat_to_vec_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_12), - mat_to_mat_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_6_2), + num_to_num_scenarios(x_; dx = dx_, dy = dy_), + onevec_to_onevec_scenarios_onearg(x_; dx = dx_, dy = dy_), + num_to_vec_scenarios_onearg(x_; dx = dx_, dy = dy_2), + num_to_mat_scenarios_onearg(x_; dx = dx_, dy = dy_2_2), + arr_to_num_scenarios_onearg(x_6; dx = dx_6, dy = dy_, linalg), + arr_to_num_scenarios_onearg(x_2_3; dx = dx_2_3, dy = dy_, linalg), + vec_to_vec_scenarios_onearg(x_6; dx = dx_6, dy = dy_12), + vec_to_mat_scenarios_onearg(x_6; dx = dx_6, dy = dy_6_2), + mat_to_vec_scenarios_onearg(x_2_3; dx = dx_2_3, dy = dy_12), + mat_to_mat_scenarios_onearg(x_2_3; dx = dx_2_3, dy = dy_6_2), # two arguments - onevec_to_onevec_scenarios_twoarg(x_; dx=dx_, dy=dy_), - num_to_vec_scenarios_twoarg(x_; dx=dx_, dy=dy_6), - num_to_mat_scenarios_twoarg(x_; dx=dx_, dy=dy_6_2), - vec_to_vec_scenarios_twoarg(x_6; dx=dx_6, dy=dy_12), - vec_to_mat_scenarios_twoarg(x_6; dx=dx_6, dy=dy_6_2), - mat_to_vec_scenarios_twoarg(x_2_3; dx=dx_2_3, dy=dy_12), - mat_to_mat_scenarios_twoarg(x_2_3; dx=dx_2_3, dy=dy_6_2), + onevec_to_onevec_scenarios_twoarg(x_; dx = dx_, dy = dy_), + num_to_vec_scenarios_twoarg(x_; dx = dx_, dy = dy_6), + num_to_mat_scenarios_twoarg(x_; dx = dx_, dy = dy_6_2), + vec_to_vec_scenarios_twoarg(x_6; dx = dx_6, dy = dy_12), + vec_to_mat_scenarios_twoarg(x_6; dx = dx_6, dy = dy_6_2), + mat_to_vec_scenarios_twoarg(x_2_3; dx = dx_2_3, dy = dy_12), + mat_to_mat_scenarios_twoarg(x_2_3; dx = dx_2_3, dy = dy_6_2), ) smallerscens = vcat( # one argument - num_to_num_scenarios(x_; dx=dx_, dy=dy_), - onevec_to_onevec_scenarios_onearg(x_; dx=dx_, dy=dy_), - num_to_vec_scenarios_onearg(x_; dx=dx_, dy=dy_2), - num_to_mat_scenarios_onearg(x_; dx=dx_, dy=dy_2_2), - arr_to_num_scenarios_onearg(x_6[1:3]; dx=dx_6[1:3], dy=dy_, linalg), - arr_to_num_scenarios_onearg(x_2_3[1:1, 1:2]; dx=dx_2_3[1:1, 1:2], dy=dy_, linalg), - vec_to_vec_scenarios_onearg(x_6[1:3]; dx=dx_6[1:3], dy=dy_12[1:6]), - vec_to_mat_scenarios_onearg(x_6[1:3]; dx=dx_6[1:3], dy=dy_6_2[1:3, :]), - mat_to_vec_scenarios_onearg(x_2_3[1:1, 1:2]; dx=dx_2_3[1:1, 1:2], dy=dy_12[1:4]), + num_to_num_scenarios(x_; dx = dx_, dy = dy_), + onevec_to_onevec_scenarios_onearg(x_; dx = dx_, dy = dy_), + num_to_vec_scenarios_onearg(x_; dx = dx_, dy = dy_2), + num_to_mat_scenarios_onearg(x_; dx = dx_, dy = dy_2_2), + arr_to_num_scenarios_onearg(x_6[1:3]; dx = dx_6[1:3], dy = dy_, linalg), + arr_to_num_scenarios_onearg(x_2_3[1:1, 1:2]; dx = dx_2_3[1:1, 1:2], dy = dy_, linalg), + vec_to_vec_scenarios_onearg(x_6[1:3]; dx = dx_6[1:3], dy = dy_12[1:6]), + vec_to_mat_scenarios_onearg(x_6[1:3]; dx = dx_6[1:3], dy = dy_6_2[1:3, :]), + mat_to_vec_scenarios_onearg(x_2_3[1:1, 1:2]; dx = dx_2_3[1:1, 1:2], dy = dy_12[1:4]), mat_to_mat_scenarios_onearg( - x_2_3[1:1, 1:2]; dx=dx_2_3[1:1, 1:2], dy=dy_6_2[1:2, :] + x_2_3[1:1, 1:2]; dx = dx_2_3[1:1, 1:2], dy = dy_6_2[1:2, :] ), # two arguments - onevec_to_onevec_scenarios_twoarg(x_; dx=dx_, dy=dy_), - num_to_vec_scenarios_twoarg(x_; dx=dx_, dy=dy_6[1:3]), - num_to_mat_scenarios_twoarg(x_; dx=dx_, dy=dy_6_2[1:3, :]), - vec_to_vec_scenarios_twoarg(x_6[1:3]; dx=dx_6[1:3], dy=dy_12[1:6]), - vec_to_mat_scenarios_twoarg(x_6[1:3]; dx=dx_6[1:3], dy=dy_6_2[1:3, :]), - mat_to_vec_scenarios_twoarg(x_2_3[1:1, 1:2]; dx=dx_2_3[1:1, 1:2], dy=dy_12[1:4]), + onevec_to_onevec_scenarios_twoarg(x_; dx = dx_, dy = dy_), + num_to_vec_scenarios_twoarg(x_; dx = dx_, dy = dy_6[1:3]), + num_to_mat_scenarios_twoarg(x_; dx = dx_, dy = dy_6_2[1:3, :]), + vec_to_vec_scenarios_twoarg(x_6[1:3]; dx = dx_6[1:3], dy = dy_12[1:6]), + vec_to_mat_scenarios_twoarg(x_6[1:3]; dx = dx_6[1:3], dy = dy_6_2[1:3, :]), + mat_to_vec_scenarios_twoarg(x_2_3[1:1, 1:2]; dx = dx_2_3[1:1, 1:2], dy = dy_12[1:4]), mat_to_mat_scenarios_twoarg( - x_2_3[1:1, 1:2]; dx=dx_2_3[1:1, 1:2], dy=dy_6_2[1:2, :] + x_2_3[1:1, 1:2]; dx = dx_2_3[1:1, 1:2], dy = dy_6_2[1:2, :] ), ) scens_smaller_prep = map(scens, smallerscens) do s1, s2 - Scenario{operator(s1),operator_place(s1),function_place(s1)}(; - f=s1.f, - y=s1.y, - x=s1.x, - t=s1.t, - contexts=s1.contexts, - res1=s1.res1, - res2=s1.res2, - name=isnothing(s1.name) ? nothing : s1.name * " [smaller prep]", - prep_args=s2.prep_args, + Scenario{operator(s1), operator_place(s1), function_place(s1)}(; + f = s1.f, + y = s1.y, + x = s1.x, + t = s1.t, + contexts = s1.contexts, + res1 = s1.res1, + res2 = s1.res2, + name = isnothing(s1.name) ? nothing : s1.name * " [smaller prep]", + prep_args = s2.prep_args, ) end @@ -644,7 +644,7 @@ function default_scenarios(; include_normal && append!(final_scens, scens) include_closurified && append!(final_scens, closurify(scens)) include_constantified && append!(final_scens, constantify(scens)) - include_cachified && append!(final_scens, cachify(scens; use_tuples=use_tuples)) + include_cachified && append!(final_scens, cachify(scens; use_tuples = use_tuples)) include_constantorcachified && append!(final_scens, constantorcachify(scens)) include_smaller && append!(final_scens, scens_smaller_prep) diff --git a/DifferentiationInterfaceTest/src/scenarios/empty.jl b/DifferentiationInterfaceTest/src/scenarios/empty.jl index f421abc0a..d699f6728 100644 --- a/DifferentiationInterfaceTest/src/scenarios/empty.jl +++ b/DifferentiationInterfaceTest/src/scenarios/empty.jl @@ -6,11 +6,11 @@ end function empty_scenarios() scens = Scenario[ - Scenario{:derivative,:out}(make_empty, 1.0; res1=Float64[]), - Scenario{:derivative,:out}(make_empty!, Float64[], 1.0; res1=Float64[]), - Scenario{:gradient,:out}(sum, Float64[]; res1=Float64[]), - Scenario{:jacobian,:out}(copy, Float64[]; res1=float.(I(0))), - Scenario{:jacobian,:out}(copyto!, Float64[], Float64[]; res1=float.(I(0))), + Scenario{:derivative, :out}(make_empty, 1.0; res1 = Float64[]), + Scenario{:derivative, :out}(make_empty!, Float64[], 1.0; res1 = Float64[]), + Scenario{:gradient, :out}(sum, Float64[]; res1 = Float64[]), + Scenario{:jacobian, :out}(copy, Float64[]; res1 = float.(I(0))), + Scenario{:jacobian, :out}(copyto!, Float64[], Float64[]; res1 = float.(I(0))), ] return scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 76429253c..5761eb346 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -5,7 +5,7 @@ abstract type FunctionModifier end Return a new `Scenario` identical to `scen` except for the first- and second-order results which are set to zero. """ -function Base.zero(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function Base.zero(scen::Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} zero_res1 = if op in (:pushforward, :pullback) map(zero, scen.res1) else @@ -18,16 +18,16 @@ function Base.zero(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} else zero(scen.res2) end - return Scenario{op,pl_op,pl_fun}(; - f=scen.f, - x=scen.x, - y=scen.y, - t=scen.t, - contexts=scen.contexts, - res1=zero_res1, - res2=zero_res2, - prep_args=scen.prep_args, - name=isnothing(scen.name) ? nothing : scen.name * " [zero]", + return Scenario{op, pl_op, pl_fun}(; + f = scen.f, + x = scen.x, + y = scen.y, + t = scen.t, + contexts = scen.contexts, + res1 = zero_res1, + res2 = zero_res2, + prep_args = scen.prep_args, + name = isnothing(scen.name) ? nothing : scen.name * " [zero]", ) end @@ -36,17 +36,17 @@ end Return a new `Scenario` identical to `scen` except for the function `f` which is changed to `new_f`. """ -function change_function(scen::Scenario{op,pl_op,pl_fun}, new_f) where {op,pl_op,pl_fun} - return Scenario{op,pl_op,pl_fun}(; - f=new_f, - x=scen.x, - y=scen.y, - t=scen.t, - contexts=scen.contexts, - res1=scen.res1, - res2=scen.res2, - prep_args=scen.prep_args, - name=isnothing(scen.name) ? nothing : scen.name * " [new function]", +function change_function(scen::Scenario{op, pl_op, pl_fun}, new_f) where {op, pl_op, pl_fun} + return Scenario{op, pl_op, pl_fun}(; + f = new_f, + x = scen.x, + y = scen.y, + t = scen.t, + contexts = scen.contexts, + res1 = scen.res1, + res2 = scen.res2, + prep_args = scen.prep_args, + name = isnothing(scen.name) ? nothing : scen.name * " [new function]", ) end @@ -59,53 +59,53 @@ Return a new `Scenario` identical to `scen` except for the tangents `tang` and a Only works if `scen` is a `pushforward`, `pullback` or `hvp` scenario. """ -function batchify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function batchify(scen::Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} (; f, x, y, t, contexts, res1, res2, prep_args) = scen new_t = (only(t), -only(t)) new_prep_args = if pl_fun == :out (; - x=prep_args.x, - contexts=prep_args.contexts, - t=(only(prep_args.t), -only(prep_args.t)), + x = prep_args.x, + contexts = prep_args.contexts, + t = (only(prep_args.t), -only(prep_args.t)), ) else (; - y=prep_args.y, - x=prep_args.x, - contexts=prep_args.contexts, - t=(only(prep_args.t), -only(prep_args.t)), + y = prep_args.y, + x = prep_args.x, + contexts = prep_args.contexts, + t = (only(prep_args.t), -only(prep_args.t)), ) end if op == :pushforward || op == :pullback new_res1 = (only(res1), -only(res1)) - return Scenario{op,pl_op,pl_fun}(; + return Scenario{op, pl_op, pl_fun}(; f, x, y, - t=new_t, + t = new_t, contexts, - res1=new_res1, + res1 = new_res1, res2, - prep_args=new_prep_args, - name=isnothing(scen.name) ? nothing : scen.name * " [batchified]", + prep_args = new_prep_args, + name = isnothing(scen.name) ? nothing : scen.name * " [batchified]", ) elseif op == :hvp new_res2 = (only(res2), -only(res2)) - return Scenario{op,pl_op,pl_fun}(; + return Scenario{op, pl_op, pl_fun}(; f, x, y, - t=new_t, + t = new_t, contexts, res1, - res2=new_res2, - prep_args=new_prep_args, - name=isnothing(scen.name) ? nothing : scen.name * " [batchified]", + res2 = new_res2, + prep_args = new_prep_args, + name = isnothing(scen.name) ? nothing : scen.name * " [batchified]", ) end end -struct WritableClosure{pl_fun,F,X,Y} <: FunctionModifier +struct WritableClosure{pl_fun, F, X, Y} <: FunctionModifier f::F x_buffer::Vector{X} y_buffer::Vector{Y} @@ -114,9 +114,9 @@ struct WritableClosure{pl_fun,F,X,Y} <: FunctionModifier end function WritableClosure{pl_fun}( - f::F, x_buffer::Vector{X}, y_buffer::Vector{Y}, a, b -) where {pl_fun,F,X,Y} - return WritableClosure{pl_fun,F,X,Y}(f, x_buffer, y_buffer, a, b) + f::F, x_buffer::Vector{X}, y_buffer::Vector{Y}, a, b + ) where {pl_fun, F, X, Y} + return WritableClosure{pl_fun, F, X, Y}(f, x_buffer, y_buffer, a, b) end Base.show(io::IO, f::WritableClosure) = print(io, "WritableClosure($(f.f))") @@ -142,7 +142,7 @@ end Return a new `Scenario` identical to `scen` except for the function `f` which is made to close over differentiable data. """ -function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function closurify(scen::Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} (; f, x, y) = scen @assert isempty(scen.contexts) x_buffer = [zero(x)] @@ -150,24 +150,24 @@ function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} a = 3.0 b = [4.0] closure_f = WritableClosure{pl_fun}(f, x_buffer, y_buffer, a, b) - return Scenario{op,pl_op,pl_fun}(; - f=closure_f, - x=scen.x, - y=mymultiply(scen.y, a + only(b)), - t=scen.t, - contexts=scen.contexts, - res1=mymultiply(scen.res1, a + only(b)), - res2=mymultiply(scen.res2, a + only(b)), - prep_args=scen.prep_args, - name=isnothing(scen.name) ? nothing : scen.name * " [closurified]", + return Scenario{op, pl_op, pl_fun}(; + f = closure_f, + x = scen.x, + y = mymultiply(scen.y, a + only(b)), + t = scen.t, + contexts = scen.contexts, + res1 = mymultiply(scen.res1, a + only(b)), + res2 = mymultiply(scen.res2, a + only(b)), + prep_args = scen.prep_args, + name = isnothing(scen.name) ? nothing : scen.name * " [closurified]", ) end -struct MultiplyByConstant{pl_fun,F} <: FunctionModifier +struct MultiplyByConstant{pl_fun, F} <: FunctionModifier f::F end -MultiplyByConstant{pl_fun}(f::F) where {pl_fun,F} = MultiplyByConstant{pl_fun,F}(f) +MultiplyByConstant{pl_fun}(f::F) where {pl_fun, F} = MultiplyByConstant{pl_fun, F}(f) Base.show(io::IO, f::MultiplyByConstant) = print(io, "MultiplyByConstant($(f.f))") @@ -188,30 +188,30 @@ end Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument by which the output is multiplied. The output and result fields are updated accordingly. """ -function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f,) = scen +function constantify(scen::Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} + (; f) = scen @assert isempty(scen.contexts) multiply_f = MultiplyByConstant{pl_fun}(f) a = 3.0 - return Scenario{op,pl_op,pl_fun}(; - f=multiply_f, - x=scen.x, - y=mymultiply(scen.y, a), - t=scen.t, - contexts=(Constant(a),), - res1=mymultiply(scen.res1, a), - res2=mymultiply(scen.res2, a), - prep_args=(; scen.prep_args..., contexts=(Constant(-a),)), - name=isnothing(scen.name) ? nothing : scen.name * " [constantified]", + return Scenario{op, pl_op, pl_fun}(; + f = multiply_f, + x = scen.x, + y = mymultiply(scen.y, a), + t = scen.t, + contexts = (Constant(a),), + res1 = mymultiply(scen.res1, a), + res2 = mymultiply(scen.res2, a), + prep_args = (; scen.prep_args..., contexts = (Constant(-a),)), + name = isnothing(scen.name) ? nothing : scen.name * " [constantified]", ) end -struct StoreInCache{pl_fun,F} <: FunctionModifier +struct StoreInCache{pl_fun, F} <: FunctionModifier f::F end -function StoreInCache{pl_fun}(f::F) where {pl_fun,F} - return StoreInCache{pl_fun,F}(f) +function StoreInCache{pl_fun}(f::F) where {pl_fun, F} + return StoreInCache{pl_fun, F}(f) end Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") @@ -245,15 +245,15 @@ Return a new `Scenario` identical to `scen` except for the function `f`, which i If `tup=true` the cache is a tuple of arrays, otherwise just an array. """ -function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl_fun} - (; f,) = scen +function cachify(scen::Scenario{op, pl_op, pl_fun}; use_tuples) where {op, pl_op, pl_fun} + (; f) = scen @assert isempty(scen.contexts) cache_f = StoreInCache{pl_fun}(f) if use_tuples y_cache = if scen.y isa Number - (; useful_cache=([zero(scen.y)],), useless_cache=[zero(scen.y)]) + (; useful_cache = ([zero(scen.y)],), useless_cache = [zero(scen.y)]) else - (; useful_cache=(similar(scen.y),), useless_cache=similar(scen.y)) + (; useful_cache = (similar(scen.y),), useless_cache = similar(scen.y)) end else y_cache = if scen.y isa Number @@ -262,25 +262,25 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl similar(scen.y) end end - return Scenario{op,pl_op,pl_fun}(; - f=cache_f, - x=scen.x, - y=scen.y, - t=scen.t, - contexts=(Cache(y_cache),), - res1=scen.res1, - res2=scen.res2, - prep_args=(; scen.prep_args..., contexts=(Cache(y_cache),)), - name=isnothing(scen.name) ? nothing : scen.name * " [cachified]", + return Scenario{op, pl_op, pl_fun}(; + f = cache_f, + x = scen.x, + y = scen.y, + t = scen.t, + contexts = (Cache(y_cache),), + res1 = scen.res1, + res2 = scen.res2, + prep_args = (; scen.prep_args..., contexts = (Cache(y_cache),)), + name = isnothing(scen.name) ? nothing : scen.name * " [cachified]", ) end -struct MultiplyByConstantAndStoreInCache{pl_fun,F} <: FunctionModifier +struct MultiplyByConstantAndStoreInCache{pl_fun, F} <: FunctionModifier f::F end -function MultiplyByConstantAndStoreInCache{pl_fun}(f::F) where {pl_fun,F} - return MultiplyByConstantAndStoreInCache{pl_fun,F}(f) +function MultiplyByConstantAndStoreInCache{pl_fun}(f::F) where {pl_fun, F} + return MultiplyByConstantAndStoreInCache{pl_fun, F}(f) end function Base.show(io::IO, f::MultiplyByConstantAndStoreInCache) @@ -326,32 +326,32 @@ end Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional "constant or cache" argument. """ -function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f,) = scen +function constantorcachify(scen::Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} + (; f) = scen @assert isempty(scen.contexts) constantorcache_f = MultiplyByConstantAndStoreInCache{pl_fun}(f) a = 3.0 b = [4.0] constantorcache = if scen.y isa Number - (; cache=[zero(scen.y)], constant=(; a, b)) + (; cache = [zero(scen.y)], constant = (; a, b)) else - (; cache=similar(scen.y), constant=(; a, b)) + (; cache = similar(scen.y), constant = (; a, b)) end prep_constantorcache = if scen.y isa Number - (; cache=[zero(scen.y)], constant=(; a=2a, b=3b)) + (; cache = [zero(scen.y)], constant = (; a = 2a, b = 3b)) else - (; cache=similar(scen.y), constant=(; a=2a, b=3b)) + (; cache = similar(scen.y), constant = (; a = 2a, b = 3b)) end - return Scenario{op,pl_op,pl_fun}(; - f=constantorcache_f, - x=scen.x, - y=mymultiply(scen.y, a + only(b)), - t=scen.t, - contexts=(ConstantOrCache(constantorcache),), - res1=mymultiply(scen.res1, a + only(b)), - res2=mymultiply(scen.res2, a + only(b)), - prep_args=(; scen.prep_args..., contexts=(ConstantOrCache(prep_constantorcache),)), - name=isnothing(scen.name) ? nothing : scen.name * " [constantorcachified]", + return Scenario{op, pl_op, pl_fun}(; + f = constantorcache_f, + x = scen.x, + y = mymultiply(scen.y, a + only(b)), + t = scen.t, + contexts = (ConstantOrCache(constantorcache),), + res1 = mymultiply(scen.res1, a + only(b)), + res2 = mymultiply(scen.res2, a + only(b)), + prep_args = (; scen.prep_args..., contexts = (ConstantOrCache(prep_constantorcache),)), + name = isnothing(scen.name) ? nothing : scen.name * " [constantorcachified]", ) end @@ -429,8 +429,8 @@ Return a scenario identical to `scen` but where the first- and second-order resu Useful for comparison of outputs between backends. """ function compute_results( - scen::Scenario{op,pl_op,pl_fun}, backend::AbstractADType -) where {op,pl_op,pl_fun} + scen::Scenario{op, pl_op, pl_fun}, backend::AbstractADType + ) where {op, pl_op, pl_fun} (; f, y, x, t, contexts, prep_args, name) = deepcopy(scen) if pl_fun == :in if isnothing(t) @@ -449,8 +449,8 @@ function compute_results( new_res2 = get_res2(Val(op), f, backend, x, t, contexts...) end end - new_scen = Scenario{op,pl_op,pl_fun}(; - f, x, y, t, contexts, res1=new_res1, res2=new_res2, prep_args, name + new_scen = Scenario{op, pl_op, pl_fun}(; + f, x, y, t, contexts, res1 = new_res1, res2 = new_res2, prep_args, name ) return new_scen end diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index f092a9efd..dfa4e645e 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -33,7 +33,7 @@ Default values: $(TYPEDFIELDS) """ -struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,P<:NamedTuple} +struct Scenario{op, pl_op, pl_fun, F, X, Y, T <: Union{Nothing, NTuple}, C <: Tuple, R1, R2, P <: NamedTuple} "function `f` (if `pl_fun==:out`) or `f!` (if `pl_fun==:in`) to apply" f::F "primal output" @@ -51,23 +51,23 @@ struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,P< "named tuple of arguments passed to preparation, without the function - the required keys are a subset of `(; y, x, t, contexts)` depending on the operator" prep_args::P "name of the scenario for display in test sets and dataframes" - name::Union{String,Nothing} - - function Scenario{op,pl_op,pl_fun}(; - f::F, - y::Y, - x::X, - t::T, - contexts::C, - res1::R1, - res2::R2, - prep_args::P, - name::Union{String,Nothing}, - ) where {op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,P} + name::Union{String, Nothing} + + function Scenario{op, pl_op, pl_fun}(; + f::F, + y::Y, + x::X, + t::T, + contexts::C, + res1::R1, + res2::R2, + prep_args::P, + name::Union{String, Nothing}, + ) where {op, pl_op, pl_fun, F, X, Y, T, C, R1, R2, P} @assert op in ALL_OPS @assert pl_op in (:in, :out) @assert pl_fun in (:in, :out) - return new{op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,P}( + return new{op, pl_op, pl_fun, F, X, Y, T, C, R1, R2, P}( f, y, x, t, contexts, res1, res2, prep_args, name ) end @@ -78,78 +78,78 @@ function zero_contexts(contexts...) return rewrap(map(zero ∘ unwrap, contexts)...) end -function Scenario{op,pl_op}( - f, - x, - contexts::Vararg{Context}; - res1=nothing, - res2=nothing, - prep_args=(; x=zero(x), contexts=zero_contexts(contexts...)), - name=nothing, -) where {op,pl_op} +function Scenario{op, pl_op}( + f, + x, + contexts::Vararg{Context}; + res1 = nothing, + res2 = nothing, + prep_args = (; x = zero(x), contexts = zero_contexts(contexts...)), + name = nothing, + ) where {op, pl_op} y = f(x, map(unwrap, contexts)...) - return Scenario{op,pl_op,:out}(; - f, y, x, t=nothing, contexts, res1, res2, prep_args, name + return Scenario{op, pl_op, :out}(; + f, y, x, t = nothing, contexts, res1, res2, prep_args, name ) end -function Scenario{op,pl_op}( - f, - y, - x, - contexts::Vararg{Context}; - res1=nothing, - res2=nothing, - prep_args=(; y=zero(y), x=zero(x), contexts=zero_contexts(contexts...)), - name=nothing, -) where {op,pl_op} +function Scenario{op, pl_op}( + f, + y, + x, + contexts::Vararg{Context}; + res1 = nothing, + res2 = nothing, + prep_args = (; y = zero(y), x = zero(x), contexts = zero_contexts(contexts...)), + name = nothing, + ) where {op, pl_op} f(y, x, map(unwrap, contexts)...) - return Scenario{op,pl_op,:in}(; - f, y, x, t=nothing, contexts, res1, res2, prep_args, name + return Scenario{op, pl_op, :in}(; + f, y, x, t = nothing, contexts, res1, res2, prep_args, name ) end -function Scenario{op,pl_op}( - f, - x, - t::NTuple, - contexts::Vararg{Context}; - res1=nothing, - res2=nothing, - prep_args=(; x=zero(x), t=map(zero, t), contexts=zero_contexts(contexts...)), - name=nothing, -) where {op,pl_op} +function Scenario{op, pl_op}( + f, + x, + t::NTuple, + contexts::Vararg{Context}; + res1 = nothing, + res2 = nothing, + prep_args = (; x = zero(x), t = map(zero, t), contexts = zero_contexts(contexts...)), + name = nothing, + ) where {op, pl_op} y = f(x, map(unwrap, contexts)...) - return Scenario{op,pl_op,:out}(; f, y, x, t, contexts, res1, res2, prep_args, name) + return Scenario{op, pl_op, :out}(; f, y, x, t, contexts, res1, res2, prep_args, name) end -function Scenario{op,pl_op}( - f, - y, - x, - t::NTuple, - contexts::Vararg{Context}; - res1=nothing, - res2=nothing, - prep_args=(; y=zero(y), x=zero(x), t=map(zero, t), contexts=zero_contexts(contexts...)), - name=nothing, -) where {op,pl_op} +function Scenario{op, pl_op}( + f, + y, + x, + t::NTuple, + contexts::Vararg{Context}; + res1 = nothing, + res2 = nothing, + prep_args = (; y = zero(y), x = zero(x), t = map(zero, t), contexts = zero_contexts(contexts...)), + name = nothing, + ) where {op, pl_op} f(y, x, map(unwrap, contexts)...) - return Scenario{op,pl_op,:in}(; f, y, x, t, contexts, res1, res2, prep_args, name) + return Scenario{op, pl_op, :in}(; f, y, x, t, contexts, res1, res2, prep_args, name) end Base.:(==)(scen1::Scenario, scen2::Scenario) = false function Base.:(==)( - scen1::Scenario{op,pl_op,pl_fun}, scen2::Scenario{op,pl_op,pl_fun} -) where {op,pl_op,pl_fun} + scen1::Scenario{op, pl_op, pl_fun}, scen2::Scenario{op, pl_op, pl_fun} + ) where {op, pl_op, pl_fun} eq_f = scen1.f == scen2.f eq_x = scen1.x == scen2.x eq_y = scen1.y == scen2.y eq_t = scen1.t == scen2.t eq_contexts = all( map(scen1.contexts, scen2.contexts) do c1, c2 - if c1 isa Union{Cache,ConstantOrCache} || c2 isa Union{Cache,ConstantOrCache} + if c1 isa Union{Cache, ConstantOrCache} || c2 isa Union{Cache, ConstantOrCache} return true else return c1 == c2 @@ -163,8 +163,8 @@ function Base.:(==)( end operator(::Scenario{op}) where {op} = op -operator_place(::Scenario{op,pl_op}) where {op,pl_op} = pl_op -function_place(::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} = pl_fun +operator_place(::Scenario{op, pl_op}) where {op, pl_op} = pl_op +function_place(::Scenario{op, pl_op, pl_fun}) where {op, pl_op, pl_fun} = pl_fun function order(scen::Scenario) if operator(scen) in [:pushforward, :pullback, :derivative, :gradient, :jacobian] @@ -178,25 +178,25 @@ function compatible(backend::AbstractADType, scen::Scenario) place_compatible = function_place(scen) == :out || Bool(inplace_support(backend)) sparse_compatible = operator(scen) in (:jacobian, :hessian) || !isa(backend, AutoSparse) secondorder_compatible = - order(scen) == 2 || !isa(backend, Union{SecondOrder,AutoSparse{<:SecondOrder}}) + order(scen) == 2 || !isa(backend, Union{SecondOrder, AutoSparse{<:SecondOrder}}) mixedmode_compatible = operator(scen) == :jacobian || !isa(backend, AutoSparse{<:MixedMode}) return place_compatible && - secondorder_compatible && - sparse_compatible && - mixedmode_compatible + secondorder_compatible && + sparse_compatible && + mixedmode_compatible end function group_by_operator(scenarios::AbstractVector{<:Scenario}) return Dict( op => filter(s -> operator(s) == op, scenarios) for - op in unique(operator.(scenarios)) + op in unique(operator.(scenarios)) ) end function Base.show( - io::IO, scen::Scenario{op,pl_op,pl_fun,F,X,Y,T} -) where {op,pl_op,pl_fun,F,X,Y,T} + io::IO, scen::Scenario{op, pl_op, pl_fun, F, X, Y, T} + ) where {op, pl_op, pl_fun, F, X, Y, T} if isnothing(scen.name) print(io, "Scenario{$(repr(op)),$(repr(pl_op))} $(string(scen.f)) : $X -> $Y") if op in (:pushforward, :pullback, :hvp) diff --git a/DifferentiationInterfaceTest/src/scenarios/sparse.jl b/DifferentiationInterfaceTest/src/scenarios/sparse.jl index 3c914c053..1cee484d1 100644 --- a/DifferentiationInterfaceTest/src/scenarios/sparse.jl +++ b/DifferentiationInterfaceTest/src/scenarios/sparse.jl @@ -41,11 +41,11 @@ function sparse_vec_to_vec_scenarios(x::AbstractVector) append!( scens, [ - Scenario{:jacobian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = jac ), - Scenario{:jacobian,pl_op}( - f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f!, y, x; prep_args = (; y = zero(y), x = x_prep, contexts = ()), res1 = jac ), ], ) @@ -82,11 +82,11 @@ function sparse_mat_to_vec_scenarios(x::AbstractMatrix) append!( scens, [ - Scenario{:jacobian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = jac ), - Scenario{:jacobian,pl_op}( - f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f!, y, x; prep_args = (; y = zero(y), x = x_prep, contexts = ()), res1 = jac ), ], ) @@ -120,11 +120,11 @@ function sparse_vec_to_mat_scenarios(x::AbstractVector) append!( scens, [ - Scenario{:jacobian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = jac ), - Scenario{:jacobian,pl_op}( - f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f!, y, x; prep_args = (; y = zero(y), x = x_prep, contexts = ()), res1 = jac ), ], ) @@ -160,11 +160,11 @@ function sparse_mat_to_mat_scenarios(x::AbstractMatrix) append!( scens, [ - Scenario{:jacobian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = jac ), - Scenario{:jacobian,pl_op}( - f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f!, y, x; prep_args = (; y = zero(y), x = x_prep, contexts = ()), res1 = jac ), ], ) @@ -207,8 +207,8 @@ function sparse_vec_to_num_scenarios(x::AbstractVector) append!( scens, [ - Scenario{:hessian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=grad, res2=hess + Scenario{:hessian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = grad, res2 = hess ), ], ) @@ -239,8 +239,8 @@ function sparse_mat_to_num_scenarios(x::AbstractMatrix) append!( scens, [ - Scenario{:hessian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=grad, res2=hess + Scenario{:hessian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = grad, res2 = hess ), ], ) @@ -258,7 +258,7 @@ end ### Linear map -struct SquareLinearMap{M<:AbstractMatrix} +struct SquareLinearMap{M <: AbstractMatrix} A::M end @@ -292,11 +292,11 @@ function squarelinearmap_scenarios(x::AbstractVector, band_sizes) append!( scens, [ - Scenario{:jacobian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = jac ), - Scenario{:jacobian,pl_op}( - f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + Scenario{:jacobian, pl_op}( + f!, y, x; prep_args = (; y = zero(y), x = x_prep, contexts = ()), res1 = jac ), ], ) @@ -307,7 +307,7 @@ end ### Quadratic form -struct SquareQuadraticForm{M<:AbstractMatrix} +struct SquareQuadraticForm{M <: AbstractMatrix} A::M end @@ -355,8 +355,8 @@ function squarequadraticform_scenarios(x::AbstractVector, band_sizes) for pl_op in (:out, :in) push!( scens, - Scenario{:hessian,pl_op}( - f, x; prep_args=(; x=x_prep, contexts=()), res1=grad, res2=hess + Scenario{:hessian, pl_op}( + f, x; prep_args = (; x = x_prep, contexts = ()), res1 = grad, res2 = hess ), ) end @@ -372,12 +372,12 @@ end Create a vector of [`Scenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians. """ function sparse_scenarios(; - band_sizes=[5, 10, 20], - include_constantified=false, - include_cachified=false, - include_constantorcachified=false, - use_tuples=false, -) + band_sizes = [5, 10, 20], + include_constantified = false, + include_cachified = false, + include_constantorcachified = false, + use_tuples = false, + ) x_6 = float.(1:6) x_2_3 = float.(reshape(1:6, 2, 3)) x_50 = float.(range(1, 2, 50)) diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index fcc038ddf..daf59b405 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -67,48 +67,48 @@ Each setting tests/benchmarks a different subset of calls: - `adaptive_batchsize=true`: whether to cap the backend's preset batch size (when it exists) to prevent errors on small inputs """ function test_differentiation( - backends::Vector{<:AbstractADType}, - scenarios::Vector{<:Scenario}=default_scenarios(); - testset_name::Union{String,Nothing}=nothing, - # test categories - correctness::Bool=true, - type_stability::Symbol=:none, - allocations::Symbol=:none, - benchmark::Symbol=:none, - # misc options - excluded::Vector{Symbol}=Symbol[], - detailed::Bool=false, - logging::Bool=false, - # correctness options - isapprox=isapprox, - atol::Real=0, - rtol::Real=1e-3, - scenario_intact::Bool=true, - sparsity::Bool=false, - reprepare::Bool=true, - # type stability options - ignored_modules=nothing, - function_filter=if VERSION >= v"1.11" - @nospecialize(f) -> true - else - @nospecialize(f) -> f != Base.mapreduce_empty # fix for `mapreduce` in jacobian and hessian - end, - # allocs options - skip_allocations::Bool=false, # private, only for code coverage - # benchmark options - count_calls::Bool=true, - benchmark_test::Bool=true, - benchmark_seconds::Real=1, - benchmark_aggregation=minimum, - # batch size - adaptive_batchsize::Bool=true, -) + backends::Vector{<:AbstractADType}, + scenarios::Vector{<:Scenario} = default_scenarios(); + testset_name::Union{String, Nothing} = nothing, + # test categories + correctness::Bool = true, + type_stability::Symbol = :none, + allocations::Symbol = :none, + benchmark::Symbol = :none, + # misc options + excluded::Vector{Symbol} = Symbol[], + detailed::Bool = false, + logging::Bool = false, + # correctness options + isapprox = isapprox, + atol::Real = 0, + rtol::Real = 1.0e-3, + scenario_intact::Bool = true, + sparsity::Bool = false, + reprepare::Bool = true, + # type stability options + ignored_modules = nothing, + function_filter = if VERSION >= v"1.11" + @nospecialize(f) -> true + else + @nospecialize(f) -> f != Base.mapreduce_empty # fix for `mapreduce` in jacobian and hessian + end, + # allocs options + skip_allocations::Bool = false, # private, only for code coverage + # benchmark options + count_calls::Bool = true, + benchmark_test::Bool = true, + benchmark_seconds::Real = 1, + benchmark_aggregation = minimum, + # batch size + adaptive_batchsize::Bool = true, + ) @assert type_stability in (:none, :prepared, :full) @assert allocations in (:none, :prepared, :full) @assert benchmark in (:none, :prepared, :full) scenarios = filter(s -> !(operator(s) in excluded), scenarios) - scenarios = sort(scenarios; by=s -> (operator(s), string(s.f))) + scenarios = sort(scenarios; by = s -> (operator(s), string(s.f))) if isnothing(testset_name) title_additions = @@ -122,18 +122,18 @@ function test_differentiation( benchmark_data = DifferentiationBenchmarkDataRow[] - prog = ProgressUnknown(; desc="$title", spinner=true, enabled=logging) + prog = ProgressUnknown(; desc = "$title", spinner = true, enabled = logging) @testset verbose = true "$title" begin @testset verbose = detailed "$backend" for (i, backend) in enumerate(backends) filtered_scenarios = filter(s -> compatible(backend, s), scenarios) grouped_scenarios = group_by_operator(filtered_scenarios) @testset verbose = detailed "$op" for (j, (op, op_group)) in - enumerate(pairs(grouped_scenarios)) + enumerate(pairs(grouped_scenarios)) @testset "$scen" for (k, scen) in enumerate(op_group) next!( prog; - showvalues=[ + showvalues = [ (:backend, "$backend - $i/$(length(backends))"), (:scenario_type, "$op - $j/$(length(grouped_scenarios))"), (:scenario, "$k/$(length(op_group))"), @@ -171,7 +171,7 @@ function test_differentiation( test_jet( adapted_backend, scen; - subset=type_stability, + subset = type_stability, ignored_modules, function_filter, ) @@ -181,8 +181,8 @@ function test_differentiation( test_alloccheck( adapted_backend, scen; - subset=allocations, - skip=skip_allocations, + subset = allocations, + skip = skip_allocations, ) end yield() @@ -192,7 +192,7 @@ function test_differentiation( adapted_backend, scen; logging, - subset=benchmark, + subset = benchmark, count_calls, benchmark_test, benchmark_seconds, @@ -228,26 +228,26 @@ Shortcut for [`test_differentiation`](@ref) with only benchmarks and no correctn Specifying the set of scenarios is mandatory for this function. """ function benchmark_differentiation( - backends, - scenarios::Vector{<:Scenario}; - testset_name::Union{String,Nothing}=nothing, - benchmark::Symbol=:prepared, - excluded::Vector{Symbol}=Symbol[], - logging::Bool=false, - count_calls::Bool=true, - benchmark_test::Bool=true, - benchmark_seconds::Real=1, - benchmark_aggregation=minimum, - # batch size - adaptive_batchsize::Bool=true, -) + backends, + scenarios::Vector{<:Scenario}; + testset_name::Union{String, Nothing} = nothing, + benchmark::Symbol = :prepared, + excluded::Vector{Symbol} = Symbol[], + logging::Bool = false, + count_calls::Bool = true, + benchmark_test::Bool = true, + benchmark_seconds::Real = 1, + benchmark_aggregation = minimum, + # batch size + adaptive_batchsize::Bool = true, + ) return test_differentiation( backends, scenarios; testset_name, - correctness=false, - type_stability=:none, - allocations=:none, + correctness = false, + type_stability = :none, + allocations = :none, benchmark, logging, excluded, diff --git a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl index f30454de3..2061d1872 100644 --- a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl @@ -1,5 +1,5 @@ function test_noallocs(skip::Bool, func, args...) - possible_allocations = check_allocs(func, typeof.(args); ignore_throw=true) + possible_allocations = check_allocs(func, typeof.(args); ignore_throw = true) return @test isempty(possible_allocations) skip = skip end @@ -18,15 +18,15 @@ for op in ALL_OPS val_and_op! = Symbol(val_prefix, op!) prep_op = Symbol("prepare_", op) - S1out = Scenario{op,:out,:out} - S1in = Scenario{op,:in,:out} - S2out = Scenario{op,:out,:in} - S2in = Scenario{op,:in,:in} + S1out = Scenario{op, :out, :out} + S1in = Scenario{op, :in, :out} + S2out = Scenario{op, :out, :in} + S2in = Scenario{op, :in, :in} if op in [:derivative, :gradient, :jacobian] @eval function test_alloccheck( - ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool + ) (; f, x, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @@ -40,8 +40,8 @@ for op in ALL_OPS end @eval function test_alloccheck( - ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool + ) (; f, x, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) @@ -61,8 +61,8 @@ for op in ALL_OPS op == :gradient && continue @eval function test_alloccheck( - ba::AbstractADType, scen::$S2out; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S2out; subset::Symbol, skip::Bool + ) (; f, x, y, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) (subset == :full) && test_noallocs( @@ -77,8 +77,8 @@ for op in ALL_OPS end @eval function test_alloccheck( - ba::AbstractADType, scen::$S2in; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S2in; subset::Symbol, skip::Bool + ) (; f, x, y, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) @@ -98,8 +98,8 @@ for op in ALL_OPS elseif op in [:second_derivative, :hessian] @eval function test_alloccheck( - ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool + ) (; f, x, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @@ -113,8 +113,8 @@ for op in ALL_OPS end @eval function test_alloccheck( - ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool + ) (; f, x, res1, res2, contexts, prep_args) = deepcopy(scen) res1_sim, res2_sim = mysimilar(res1), mysimilar(res2) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) @@ -133,8 +133,8 @@ for op in ALL_OPS elseif op in [:pushforward, :pullback] @eval function test_alloccheck( - ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool + ) (; f, x, t, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && test_noallocs( @@ -149,8 +149,8 @@ for op in ALL_OPS end @eval function test_alloccheck( - ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool + ) (; f, x, t, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) @@ -169,8 +169,8 @@ for op in ALL_OPS end @eval function test_alloccheck( - ba::AbstractADType, scen::$S2out; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S2out; subset::Symbol, skip::Bool + ) (; f, x, y, t, contexts, prep_args) = deepcopy(scen) prep = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... @@ -195,8 +195,8 @@ for op in ALL_OPS end @eval function test_alloccheck( - ba::AbstractADType, scen::$S2in; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S2in; subset::Symbol, skip::Bool + ) (; f, x, y, t, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) prep = $prep_op( @@ -226,8 +226,8 @@ for op in ALL_OPS elseif op in [:hvp] @eval function test_alloccheck( - ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool + ) (; f, x, t, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && test_noallocs( @@ -242,8 +242,8 @@ for op in ALL_OPS end @eval function test_alloccheck( - ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool - ) + ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool + ) (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) res1_sim, res2_sim = mysimilar(res1), mysimilar(res2) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) diff --git a/DifferentiationInterfaceTest/src/tests/benchmark.jl b/DifferentiationInterfaceTest/src/tests/benchmark.jl index fca1ce77a..d71a54383 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark.jl @@ -64,7 +64,7 @@ Base.@kwdef struct DifferentiationBenchmarkDataRow{T} "differentiation operator used for benchmarking, e.g. `:gradient` or `:hessian`" operator::Symbol "whether the operator had been prepared" - prepared::Union{Nothing,Bool} + prepared::Union{Nothing, Bool} "number of calls to the differentiated function for one call to the operator" calls::Int "number of benchmarking samples taken" @@ -84,28 +84,28 @@ Base.@kwdef struct DifferentiationBenchmarkDataRow{T} end function record!( - data::Vector{DifferentiationBenchmarkDataRow}; - backend::AbstractADType, - scenario::Scenario, - operator::String, - prepared::Union{Nothing,Bool}, - bench::Benchmark, - calls::Integer, - aggregation, -) + data::Vector{DifferentiationBenchmarkDataRow}; + backend::AbstractADType, + scenario::Scenario, + operator::String, + prepared::Union{Nothing, Bool}, + bench::Benchmark, + calls::Integer, + aggregation, + ) row = DifferentiationBenchmarkDataRow(; - backend=backend, - scenario=scenario, - operator=Symbol(operator), - prepared=prepared, - calls=calls, - samples=length(bench.samples), - evals=Int(bench.samples[1].evals), - time=aggregation(getfield.(bench.samples, :time)), - allocs=aggregation(getfield.(bench.samples, :allocs)), - bytes=aggregation(getfield.(bench.samples, :bytes)), - gc_fraction=aggregation(getfield.(bench.samples, :gc_fraction)), - compile_fraction=aggregation(getfield.(bench.samples, :compile_fraction)), + backend = backend, + scenario = scenario, + operator = Symbol(operator), + prepared = prepared, + calls = calls, + samples = length(bench.samples), + evals = Int(bench.samples[1].evals), + time = aggregation(getfield.(bench.samples, :time)), + allocs = aggregation(getfield.(bench.samples, :allocs)), + bytes = aggregation(getfield.(bench.samples, :bytes)), + gc_fraction = aggregation(getfield.(bench.samples, :gc_fraction)), + compile_fraction = aggregation(getfield.(bench.samples, :compile_fraction)), ) return push!(data, row) end diff --git a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl index 64464ba2c..b8a01fc29 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl @@ -29,27 +29,27 @@ for op in ALL_OPS val_and_op! = Symbol(val_prefix, op!) prep_op = Symbol("prepare_", op) - S1out = Scenario{op,:out,:out} - S1in = Scenario{op,:in,:out} - S2out = Scenario{op,:out,:in} - S2in = Scenario{op,:in,:in} + S1out = Scenario{op, :out, :out} + S1in = Scenario{op, :in, :out} + S2out = Scenario{op, :out, :in} + S2in = Scenario{op, :in, :in} @eval function run_benchmark!( - data::Vector{DifferentiationBenchmarkDataRow}, - backend::AbstractADType, - scenario::Union{$S1out,$S1in,$S2out,$S2in}; - logging::Bool, - subset::Symbol, - count_calls::Bool, - benchmark_test::Bool, - benchmark_seconds::Real, - benchmark_aggregation, - ) + data::Vector{DifferentiationBenchmarkDataRow}, + backend::AbstractADType, + scenario::Union{$S1out, $S1in, $S2out, $S2in}; + logging::Bool, + subset::Symbol, + count_calls::Bool, + benchmark_test::Bool, + benchmark_seconds::Real, + benchmark_aggregation, + ) @assert subset in (:full, :prepared) bench_success = true bench_result = try - benchmark_aux(backend, scenario; subset, s=benchmark_seconds) + benchmark_aux(backend, scenario; subset, s = benchmark_seconds) catch exception bench_success = false logging && @warn "Error during benchmarking" backend scenario exception @@ -60,7 +60,7 @@ for op in ALL_OPS if count_calls count_success = true calls_result = try - calls_aux(backend, scenario; subset, s=nothing) + calls_aux(backend, scenario; subset, s = nothing) catch exception count_success = false logging && @warn "Error during call counting" backend scenario exception @@ -72,7 +72,7 @@ for op in ALL_OPS end prep_string = $(string(prep_op)) - if scenario isa Union{$S1out,$S2out} + if scenario isa Union{$S1out, $S2out} valop_string = $(string(val_and_op)) op_string = $(string(op)) else @@ -84,52 +84,52 @@ for op in ALL_OPS data; backend, scenario, - operator=valop_string, - prepared=true, - bench=bench_result.prepared_valop, - calls=calls_result.prepared_valop, - aggregation=benchmark_aggregation, + operator = valop_string, + prepared = true, + bench = bench_result.prepared_valop, + calls = calls_result.prepared_valop, + aggregation = benchmark_aggregation, ) record!( data; backend, scenario, - operator=op_string, - prepared=true, - bench=bench_result.prepared_op, - calls=calls_result.prepared_op, - aggregation=benchmark_aggregation, + operator = op_string, + prepared = true, + bench = bench_result.prepared_op, + calls = calls_result.prepared_op, + aggregation = benchmark_aggregation, ) if subset == :full record!( data; backend, scenario, - operator=prep_string, - prepared=nothing, - bench=bench_result.preparation, - calls=calls_result.preparation, - aggregation=benchmark_aggregation, + operator = prep_string, + prepared = nothing, + bench = bench_result.preparation, + calls = calls_result.preparation, + aggregation = benchmark_aggregation, ) record!( data; backend, scenario, - operator=valop_string, - prepared=false, - bench=bench_result.unprepared_valop, - calls=calls_result.unprepared_valop, - aggregation=benchmark_aggregation, + operator = valop_string, + prepared = false, + bench = bench_result.unprepared_valop, + calls = calls_result.unprepared_valop, + aggregation = benchmark_aggregation, ) record!( data; backend, scenario, - operator=op_string, - prepared=false, - bench=bench_result.unprepared_op, - calls=calls_result.unprepared_op, - aggregation=benchmark_aggregation, + operator = op_string, + prepared = false, + bench = bench_result.unprepared_op, + calls = calls_result.unprepared_op, + aggregation = benchmark_aggregation, ) end return nothing diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index a19c256df..ef76414b9 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -1,4 +1,4 @@ -has_size(::Union{Number,AbstractArray}) = true +has_size(::Union{Number, AbstractArray}) = true has_size(_x) = false function should_reprepare(scen) @@ -40,28 +40,28 @@ for op in ALL_OPS SecondDerivativePrep end - S1out = Scenario{op,:out,:out} - S1in = Scenario{op,:in,:out} - S2out = Scenario{op,:out,:in} - S2in = Scenario{op,:in,:in} + S1out = Scenario{op, :out, :out} + S1in = Scenario{op, :in, :out} + S2out = Scenario{op, :out, :in} + S2in = Scenario{op, :in, :in} if op in [:derivative, :gradient, :jacobian] @eval function test_correctness( - ba::AbstractADType, - scen::$S1out; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1out; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep0 = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prep_nostrict0 = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) + f, ba, prep_args.x, prep_args.contexts...; strict = Val(false) ) if reprepare && should_reprepare(scen) prep = $prep_op!(f, prep0, ba, x, contexts...) @@ -103,21 +103,21 @@ for op in ALL_OPS end @eval function test_correctness( - ba::AbstractADType, - scen::$S1in; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1in; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep0 = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prep_nostrict0 = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) + f, ba, prep_args.x, prep_args.contexts...; strict = Val(false) ) if reprepare && should_reprepare(scen) prep = $prep_op!(f, prep0, ba, x, contexts...) @@ -173,15 +173,15 @@ for op in ALL_OPS op == :gradient && continue @eval function test_correctness( - ba::AbstractADType, - scen::$S2out; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S2out; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -192,12 +192,12 @@ for op in ALL_OPS ba, prep_args.x, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) if reprepare && - has_size(x) && - has_size(y) && - (size(x) != size(prep_args.x) || size(y) != prep_args.y) + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep0, ba, x, contexts...) prep_nostrict = $prep_op!(f, y, prep_nostrict0, ba, x, contexts...) else @@ -243,15 +243,15 @@ for op in ALL_OPS end @eval function test_correctness( - ba::AbstractADType, - scen::$S2in; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S2in; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -262,12 +262,12 @@ for op in ALL_OPS ba, prep_args.x, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) if reprepare && - has_size(x) && - has_size(y) && - (size(x) != size(prep_args.x) || size(y) != prep_args.y) + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep0, ba, x, contexts...) prep_nostrict = $prep_op!(f, y, prep_nostrict0, ba, x, contexts...) else @@ -322,21 +322,21 @@ for op in ALL_OPS elseif op in [:second_derivative, :hessian] @eval function test_correctness( - ba::AbstractADType, - scen::$S1out; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1out; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep0 = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prep_nostrict0 = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) + f, ba, prep_args.x, prep_args.contexts...; strict = Val(false) ) if reprepare && should_reprepare(scen) prep = $prep_op!(f, prep0, ba, x, contexts...) @@ -380,21 +380,21 @@ for op in ALL_OPS end @eval function test_correctness( - ba::AbstractADType, - scen::$S1in; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1in; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep0 = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prep_nostrict0 = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) + f, ba, prep_args.x, prep_args.contexts...; strict = Val(false) ) if reprepare && should_reprepare(scen) prep = $prep_op!(f, prep0, ba, x, contexts...) @@ -453,15 +453,15 @@ for op in ALL_OPS elseif op in [:pushforward, :pullback] @eval function test_correctness( - ba::AbstractADType, - scen::$S1out; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1out; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -472,7 +472,7 @@ for op in ALL_OPS prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && should_reprepare(scen) @@ -511,15 +511,15 @@ for op in ALL_OPS end @eval function test_correctness( - ba::AbstractADType, - scen::$S1in; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1in; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -530,7 +530,7 @@ for op in ALL_OPS prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && should_reprepare(scen) @@ -581,15 +581,15 @@ for op in ALL_OPS end @eval function test_correctness( - ba::AbstractADType, - scen::$S2out; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S2out; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -603,13 +603,13 @@ for op in ALL_OPS prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...) if reprepare && - has_size(x) && - has_size(y) && - (size(x) != size(prep_args.x) || size(y) != prep_args.y) + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep0, ba, x, t, contexts...) prep_nostrict = $prep_op!(f, y, prep_nostrict0, ba, x, t, contexts...) else @@ -655,15 +655,15 @@ for op in ALL_OPS end @eval function test_correctness( - ba::AbstractADType, - scen::$S2in; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S2in; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -677,13 +677,13 @@ for op in ALL_OPS prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...) if reprepare && - has_size(x) && - has_size(y) && - (size(x) != size(prep_args.x) || size(y) != prep_args.y) + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep0, ba, x, t, contexts...) prep_nostrict = $prep_op!(f, y, prep_nostrict0, ba, x, t, contexts...) else @@ -734,15 +734,15 @@ for op in ALL_OPS elseif op in [:hvp] @eval function test_correctness( - ba::AbstractADType, - scen::$S1out; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1out; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -753,7 +753,7 @@ for op in ALL_OPS prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && should_reprepare(scen) @@ -792,15 +792,15 @@ for op in ALL_OPS end @eval function test_correctness( - ba::AbstractADType, - scen::$S1in; - isapprox::Function, - atol::Real, - rtol::Real, - scenario_intact::Bool, - sparsity::Bool, - reprepare::Bool, - ) + ba::AbstractADType, + scen::$S1in; + isapprox::Function, + atol::Real, + rtol::Real, + scenario_intact::Bool, + sparsity::Bool, + reprepare::Bool, + ) ≈(x, y) = isapprox(x, y; atol, rtol) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) preptup_cands_val, preptup_cands_noval = map(1:2) do _ @@ -811,7 +811,7 @@ for op in ALL_OPS prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(false), + strict = Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && should_reprepare(scen) diff --git a/DifferentiationInterfaceTest/src/tests/prep_eval.jl b/DifferentiationInterfaceTest/src/tests/prep_eval.jl index 52ad19473..713fcf5e7 100644 --- a/DifferentiationInterfaceTest/src/tests/prep_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/prep_eval.jl @@ -35,13 +35,13 @@ for op in ALL_OPS SecondDerivativePrep end - S1out = Scenario{op,:out,:out} - S1in = Scenario{op,:in,:out} - S2out = Scenario{op,:out,:in} - S2in = Scenario{op,:in,:in} + S1out = Scenario{op, :out, :out} + S1in = Scenario{op, :in, :out} + S2out = Scenario{op, :out, :in} + S2in = Scenario{op, :in, :in} if op in [:derivative, :gradient, :jacobian] - @eval function test_prep(ba::AbstractADType, scen::$S1out;) + @eval function test_prep(ba::AbstractADType, scen::$S1out) (; f, x, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) @test prep isa $P @@ -50,7 +50,7 @@ for op in ALL_OPS return nothing end - @eval function test_prep(ba::AbstractADType, scen::$S1in;) + @eval function test_prep(ba::AbstractADType, scen::$S1in) (; f, x, res1, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) @test prep isa $P @@ -63,7 +63,7 @@ for op in ALL_OPS op == :gradient && continue - @eval function test_prep(ba::AbstractADType, scen::$S2out;) + @eval function test_prep(ba::AbstractADType, scen::$S2out) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) @test prep isa $P @@ -72,7 +72,7 @@ for op in ALL_OPS return nothing end - @eval function test_prep(ba::AbstractADType, scen::$S2in;) + @eval function test_prep(ba::AbstractADType, scen::$S2in) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) @test prep isa $P @@ -86,7 +86,7 @@ for op in ALL_OPS end elseif op in [:second_derivative, :hessian] - @eval function test_prep(ba::AbstractADType, scen::$S1out;) + @eval function test_prep(ba::AbstractADType, scen::$S1out) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) @test prep isa $P @@ -95,7 +95,7 @@ for op in ALL_OPS return nothing end - @eval function test_prep(ba::AbstractADType, scen::$S1in;) + @eval function test_prep(ba::AbstractADType, scen::$S1in) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) @test prep isa $P @@ -107,7 +107,7 @@ for op in ALL_OPS end elseif op in [:pushforward, :pullback] - @eval function test_prep(ba::AbstractADType, scen::$S1out;) + @eval function test_prep(ba::AbstractADType, scen::$S1out) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) @test prep isa $P @@ -116,7 +116,7 @@ for op in ALL_OPS return nothing end - @eval function test_prep(ba::AbstractADType, scen::$S1in;) + @eval function test_prep(ba::AbstractADType, scen::$S1in) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) @test prep isa $P @@ -127,7 +127,7 @@ for op in ALL_OPS return nothing end - @eval function test_prep(ba::AbstractADType, scen::$S2out;) + @eval function test_prep(ba::AbstractADType, scen::$S2out) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... @@ -138,7 +138,7 @@ for op in ALL_OPS return nothing end - @eval function test_prep(ba::AbstractADType, scen::$S2in;) + @eval function test_prep(ba::AbstractADType, scen::$S2in) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... @@ -154,7 +154,7 @@ for op in ALL_OPS end elseif op in [:hvp] - @eval function test_prep(ba::AbstractADType, scen::$S1out;) + @eval function test_prep(ba::AbstractADType, scen::$S1out) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) @test prep isa $P @@ -163,7 +163,7 @@ for op in ALL_OPS return nothing end - @eval function test_prep(ba::AbstractADType, scen::$S1in;) + @eval function test_prep(ba::AbstractADType, scen::$S1in) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) @test prep isa $P diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index 5219bf410..911f6463e 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -13,376 +13,376 @@ for op in ALL_OPS val_and_op! = Symbol(val_prefix, op!) prep_op = Symbol("prepare_", op) - S1out = Scenario{op,:out,:out} - S1in = Scenario{op,:in,:out} - S2out = Scenario{op,:out,:in} - S2in = Scenario{op,:in,:in} + S1out = Scenario{op, :out, :out} + S1in = Scenario{op, :in, :out} + S2out = Scenario{op, :out, :in} + S2in = Scenario{op, :in, :in} if op in [:derivative, :gradient, :jacobian] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, ba, x, contexts...) + function_filter $op(f, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, ba, x, contexts...) + function_filter $val_and_op(f, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, prep, ba, x, contexts...) + function_filter $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, prep, ba, x, contexts...) + function_filter $val_and_op(f, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, res1, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res1), ba, x, contexts...) + function_filter $op!(f, mysimilar(res1), ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!(f, mysimilar(res1), ba, x, contexts...) + function_filter $val_and_op!(f, mysimilar(res1), ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res1), prep, ba, x, contexts...) + function_filter $op!(f, mysimilar(res1), prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, mysimilar(res1), prep, ba, x, contexts... - ) + function_filter $val_and_op!( + f, mysimilar(res1), prep, ba, x, contexts... + ) return nothing end op == :gradient && continue @eval function test_jet( - ba::AbstractADType, - scen::$S2out; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S2out; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, y, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, prep_args.y, ba, prep_args.x, prep_args.contexts... - ) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, y, ba, x, contexts...) + function_filter $op(f, y, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, y, ba, x, contexts...) + function_filter $val_and_op(f, y, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, y, prep, ba, x, contexts...) + function_filter $op(f, y, prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, y, prep, ba, x, contexts...) + function_filter $val_and_op(f, y, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S2in; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S2in; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, y, res1, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, prep_args.y, ba, prep_args.x, prep_args.contexts... - ) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, y, mysimilar(res1), ba, x, contexts...) + function_filter $op!(f, y, mysimilar(res1), ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!(f, y, mysimilar(res1), ba, x, contexts...) + function_filter $val_and_op!(f, y, mysimilar(res1), ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, y, mysimilar(res1), prep, ba, x, contexts...) + function_filter $op!(f, y, mysimilar(res1), prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, y, mysimilar(res1), prep, ba, x, contexts... - ) + function_filter $val_and_op!( + f, y, mysimilar(res1), prep, ba, x, contexts... + ) return nothing end elseif op in [:second_derivative, :hessian] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, ba, x, contexts...) + function_filter $op(f, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, ba, x, contexts...) + function_filter $val_and_op(f, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, prep, ba, x, contexts...) + function_filter $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, prep, ba, x, contexts...) + function_filter $val_and_op(f, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, res1, res2, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res2), ba, x, contexts...) + function_filter $op!(f, mysimilar(res2), ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, mysimilar(res1), mysimilar(res2), ba, x, contexts... - ) + function_filter $val_and_op!( + f, mysimilar(res1), mysimilar(res2), ba, x, contexts... + ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res2), prep, ba, x, contexts...) + function_filter $op!(f, mysimilar(res2), prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, mysimilar(res1), mysimilar(res2), prep, ba, x, contexts... - ) + function_filter $val_and_op!( + f, mysimilar(res1), mysimilar(res2), prep, ba, x, contexts... + ) return nothing end elseif op in [:pushforward, :pullback] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, t, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, ba, prep_args.x, prep_args.t, prep_args.contexts... - ) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, ba, x, t, contexts...) + function_filter $op(f, ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, ba, x, t, contexts...) + function_filter $val_and_op(f, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, prep, ba, x, t, contexts...) + function_filter $op(f, prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, prep, ba, x, t, contexts...) + function_filter $val_and_op(f, prep, ba, x, t, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, ba, prep_args.x, prep_args.t, prep_args.contexts... - ) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res1), ba, x, t, contexts...) + function_filter $op!(f, mysimilar(res1), ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!(f, mysimilar(res1), ba, x, t, contexts...) + function_filter $val_and_op!(f, mysimilar(res1), ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res1), prep, ba, x, t, contexts...) + function_filter $op!(f, mysimilar(res1), prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, mysimilar(res1), prep, ba, x, t, contexts... - ) + function_filter $val_and_op!( + f, mysimilar(res1), prep, ba, x, t, contexts... + ) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S2out; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S2out; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, y, t, contexts, prep_args) = deepcopy(scen) prep = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... - ) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, y, ba, x, t, contexts...) + function_filter $op(f, y, ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, y, ba, x, t, contexts...) + function_filter $val_and_op(f, y, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, y, prep, ba, x, t, contexts...) + function_filter $op(f, y, prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, y, prep, ba, x, t, contexts...) + function_filter $val_and_op(f, y, prep, ba, x, t, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S2in; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S2in; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, y, t, res1, contexts, prep_args) = deepcopy(scen) prep = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... - ) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, y, mysimilar(res1), ba, x, t, contexts...) + function_filter $op!(f, y, mysimilar(res1), ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, y, mysimilar(res1), ba, x, t, contexts... - ) + function_filter $val_and_op!( + f, y, mysimilar(res1), ba, x, t, contexts... + ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, y, mysimilar(res1), prep, ba, x, t, contexts...) + function_filter $op!(f, y, mysimilar(res1), prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, y, mysimilar(res1), prep, ba, x, t, contexts... - ) + function_filter $val_and_op!( + f, y, mysimilar(res1), prep, ba, x, t, contexts... + ) return nothing end elseif op in [:hvp] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, t, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, ba, prep_args.x, prep_args.t, prep_args.contexts... - ) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, ba, x, t, contexts...) + function_filter $op(f, ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, ba, x, t, contexts...) + function_filter $val_and_op(f, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, prep, ba, x, t, contexts...) + function_filter $op(f, prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, prep, ba, x, t, contexts...) + function_filter $val_and_op(f, prep, ba, x, t, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - subset::Symbol, - ignored_modules, - function_filter, - ) + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, + ) (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op( - f, ba, prep_args.x, prep_args.t, prep_args.contexts... - ) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res2), ba, x, t, contexts...) + function_filter $op!(f, mysimilar(res2), ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, mysimilar(res1), mysimilar(res2), ba, x, t, contexts... - ) + function_filter $val_and_op!( + f, mysimilar(res1), mysimilar(res2), ba, x, t, contexts... + ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res2), prep, ba, x, t, contexts...) + function_filter $op!(f, mysimilar(res2), prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, mysimilar(res1), mysimilar(res2), prep, ba, x, t, contexts... - ) + function_filter $val_and_op!( + f, mysimilar(res1), mysimilar(res2), prep, ba, x, t, contexts... + ) return nothing end end diff --git a/DifferentiationInterfaceTest/src/utils.jl b/DifferentiationInterfaceTest/src/utils.jl index 04a167871..e28ef0646 100644 --- a/DifferentiationInterfaceTest/src/utils.jl +++ b/DifferentiationInterfaceTest/src/utils.jl @@ -1,5 +1,5 @@ mysimilar(x::AbstractArray) = similar(x) -mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x) +mysimilar(x::Union{Tuple, NamedTuple}) = map(mysimilar, x) mysize(x::Number) = size(x) mysize(x::AbstractArray) = size(x) @@ -7,7 +7,7 @@ mysize(x) = missing mymultiply(x::Number, a::Number) = a * x mymultiply(x::AbstractArray, a::Number) = a .* x -mymultiply(x::Union{Tuple,NamedTuple}, a::Number) = map(Base.Fix2(mymultiply, a), x) +mymultiply(x::Union{Tuple, NamedTuple}, a::Number) = map(Base.Fix2(mymultiply, a), x) mymultiply(::Nothing, a::Number) = nothing mynnz(A::AbstractMatrix) = count(!iszero, A) diff --git a/DifferentiationInterfaceTest/test/formalities.jl b/DifferentiationInterfaceTest/test/formalities.jl index e05731041..862c0b66d 100644 --- a/DifferentiationInterfaceTest/test/formalities.jl +++ b/DifferentiationInterfaceTest/test/formalities.jl @@ -7,12 +7,12 @@ using SparseMatrixColorings: SparseMatrixColorings using Test @testset "Aqua" begin - Aqua.test_all(DifferentiationInterfaceTest; ambiguities=false, undocumented_names=true) + Aqua.test_all(DifferentiationInterfaceTest; ambiguities = false, undocumented_names = true) end @testset verbose = true "JET" begin # until https://github.com/JuliaLang/julia/pull/59321 is released if VERSION <= v"1.12-" - JET.test_package(DifferentiationInterfaceTest; target_defined_modules=true) + JET.test_package(DifferentiationInterfaceTest; target_defined_modules = true) end end diff --git a/DifferentiationInterfaceTest/test/scenario.jl b/DifferentiationInterfaceTest/test/scenario.jl index a07e0739d..194cd53b3 100644 --- a/DifferentiationInterfaceTest/test/scenario.jl +++ b/DifferentiationInterfaceTest/test/scenario.jl @@ -5,17 +5,17 @@ using ForwardDiff: ForwardDiff using Test @testset "Naming" begin - scen = Scenario{:gradient,:out}( - sum, zeros(10); res1=ones(10), name="My pretty little scenario" + scen = Scenario{:gradient, :out}( + sum, zeros(10); res1 = ones(10), name = "My pretty little scenario" ) @test string(scen) == "My pretty little scenario" testset = test_differentiation( - AutoForwardDiff(), [scen]; testset_name="My amazing test set" + AutoForwardDiff(), [scen]; testset_name = "My amazing test set" ) data = benchmark_differentiation( - AutoForwardDiff(), [scen]; testset_name="My amazing test set" + AutoForwardDiff(), [scen]; testset_name = "My amazing test set" ) end; diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index 24e64b0b0..a3240e42e 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -12,46 +12,46 @@ LOGGING = get(ENV, "CI", "false") == "false" ## Dense test_differentiation( - [AutoForwardDiff(), AutoForwardDiff(; chunksize=100)], - default_scenarios(; include_smaller=true, include_constantified=true); - logging=LOGGING, + [AutoForwardDiff(), AutoForwardDiff(; chunksize = 100)], + default_scenarios(; include_smaller = true, include_constantified = true); + logging = LOGGING, ) test_differentiation( - [AutoForwardDiff(), AutoFiniteDiff(; relstep=1e-5)], + [AutoForwardDiff(), AutoFiniteDiff(; relstep = 1.0e-5)], default_scenarios(; - include_batchified=false, - include_normal=false, - include_cachified=true, - include_constantorcachified=true, + include_batchified = false, + include_normal = false, + include_cachified = true, + include_constantorcachified = true, ); - logging=LOGGING, + logging = LOGGING, ) test_differentiation( - [AutoForwardDiff()], empty_scenarios(); excluded=[:gradient], logging=LOGGING + [AutoForwardDiff()], empty_scenarios(); excluded = [:gradient], logging = LOGGING ) test_differentiation( - [AutoFiniteDiff()], empty_scenarios(); excluded=[:jacobian], logging=LOGGING + [AutoFiniteDiff()], empty_scenarios(); excluded = [:jacobian], logging = LOGGING ) ## Sparse sparse_backend = AutoSparse( AutoForwardDiff(); - sparsity_detector=TracerSparsityDetector(), - coloring_algorithm=GreedyColoringAlgorithm(), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm(), ) test_differentiation( sparse_backend, - sparse_scenarios(; include_cachified=true, use_tuples=false); - sparsity=true, - logging=LOGGING, + sparse_scenarios(; include_cachified = true, use_tuples = false); + sparsity = true, + logging = LOGGING, ) ## Complex test_differentiation( - AutoFiniteDiff(), vcat(complex_scenarios(), complex_sparse_scenarios()); logging=LOGGING + AutoFiniteDiff(), vcat(complex_scenarios(), complex_sparse_scenarios()); logging = LOGGING ) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index f5e88e1ff..1a121ce36 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -24,18 +24,18 @@ LOGGING = get(ENV, "CI", "false") == "false" ## Generate all scenarios gpu_scenarios(; - include_constantified=true, - include_closurified=true, - include_batchified=true, - include_cachified=true, - use_tuples=true, + include_constantified = true, + include_closurified = true, + include_batchified = true, + include_cachified = true, + use_tuples = true, ) static_scenarios(; - include_constantified=true, - include_closurified=true, - include_batchified=true, - include_cachified=true, - use_tuples=true, + include_constantified = true, + include_closurified = true, + include_batchified = true, + include_cachified = true, + use_tuples = true, ) ## Weird arrays @@ -43,26 +43,26 @@ static_scenarios(; test_differentiation( AutoForwardDiff(), DIT.no_matrices(static_scenarios()); - benchmark=:prepared, - logging=LOGGING, + benchmark = :prepared, + logging = LOGGING, ) -test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING) +test_differentiation(AutoForwardDiff(), component_scenarios(); logging = LOGGING) -test_differentiation(AutoZygote(), gpu_scenarios(); excluded=SECOND_ORDER, logging=LOGGING) +test_differentiation(AutoZygote(), gpu_scenarios(); excluded = SECOND_ORDER, logging = LOGGING) ## Closures & caches test_differentiation( AutoFiniteDiff(), default_scenarios(; - include_normal=false, - include_closurified=true, - include_cachified=true, - use_tuples=true, + include_normal = false, + include_closurified = true, + include_cachified = true, + use_tuples = true, ); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ); ## Neural nets @@ -70,19 +70,19 @@ test_differentiation( test_differentiation( AutoZygote(), DIT.flux_scenarios(Random.MersenneTwister(0)); - isapprox=DIT.flux_isapprox, - rtol=1e-2, - atol=1e-4, - scenario_intact=false, - logging=LOGGING, + isapprox = DIT.flux_isapprox, + rtol = 1.0e-2, + atol = 1.0e-4, + scenario_intact = false, + logging = LOGGING, ) test_differentiation( AutoZygote(), DIT.lux_scenarios(Random.Xoshiro(63)); - isapprox=DIT.lux_isapprox, - rtol=1.0f-2, - atol=1.0f-3, - scenario_intact=false, - logging=LOGGING, + isapprox = DIT.lux_isapprox, + rtol = 1.0f-2, + atol = 1.0f-3, + scenario_intact = false, + logging = LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index b0931264d..8796d524f 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -11,38 +11,38 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( AutoZeroForward(), - map(zero, default_scenarios(; include_batchified=false)); - type_stability=safetypestab(:full), - logging=LOGGING, - reprepare=false, + map(zero, default_scenarios(; include_batchified = false)); + type_stability = safetypestab(:full), + logging = LOGGING, + reprepare = false, ) test_differentiation( AutoZeroReverse(), map( DifferentiationInterfaceTest.same_function, - default_scenarios(; include_batchified=false), + default_scenarios(; include_batchified = false), ); - correctness=false, - type_stability=safetypestab(:prepared), - logging=LOGGING, + correctness = false, + type_stability = safetypestab(:prepared), + logging = LOGGING, ) ## Benchmark data0 = benchmark_differentiation( AutoZeroForward(), - no_matrices(default_scenarios(; include_batchified=false, include_constantified=true)); - logging=LOGGING, + no_matrices(default_scenarios(; include_batchified = false, include_constantified = true)); + logging = LOGGING, ); data1 = benchmark_differentiation( AutoZeroForward(), - no_matrices(default_scenarios(; include_batchified=false)); - benchmark=:full, - logging=LOGGING, - benchmark_seconds=0.05, - benchmark_aggregation=maximum, + no_matrices(default_scenarios(; include_batchified = false)); + benchmark = :full, + logging = LOGGING, + benchmark_seconds = 0.05, + benchmark_aggregation = maximum, ); struct FakeBackend <: ADTypes.AbstractADType end @@ -50,9 +50,9 @@ ADTypes.mode(::FakeBackend) = ADTypes.ForwardMode() data2 = benchmark_differentiation( FakeBackend(), - no_matrices(default_scenarios(; include_batchified=false)); - logging=false, - benchmark_test=false, + no_matrices(default_scenarios(; include_batchified = false)); + logging = false, + benchmark_test = false, ); @testset "Benchmarking DataFrame" begin @@ -75,16 +75,16 @@ end benchmark_differentiation( AutoZeroForward(), allocfree_scenarios(); - excluded=[:pullback, :gradient], - benchmark=:prepared, - logging=LOGGING, + excluded = [:pullback, :gradient], + benchmark = :prepared, + logging = LOGGING, ), benchmark_differentiation( AutoZeroReverse(), allocfree_scenarios(); - excluded=[:pushforward, :derivative], - benchmark=:prepared, - logging=LOGGING, + excluded = [:pushforward, :derivative], + benchmark = :prepared, + logging = LOGGING, ), ) @testset "$(collect(row[1:4]))" for row in collect(eachrow(data_allocfree)) @@ -95,25 +95,25 @@ end test_differentiation( AutoZeroForward(), allocfree_scenarios(); - correctness=false, - allocations=:prepared, - excluded=[:pullback, :gradient, :jacobian], - logging=LOGGING, + correctness = false, + allocations = :prepared, + excluded = [:pullback, :gradient, :jacobian], + logging = LOGGING, ) test_differentiation( AutoZeroReverse(), allocfree_scenarios(); - correctness=false, - allocations=:prepared, - excluded=[:pushforward, :derivative, :jacobian], - logging=LOGGING, + correctness = false, + allocations = :prepared, + excluded = [:pushforward, :derivative, :jacobian], + logging = LOGGING, ) test_differentiation( AutoZeroForward(); - correctness=false, - allocations=:full, - skip_allocations=true, - logging=LOGGING, + correctness = false, + allocations = :full, + skip_allocations = true, + logging = LOGGING, )