11module ScalarTests
22
3- using ReverseDiff, ForwardDiff, Test, DiffRules, SpecialFunctions, NaNMath
3+ using ReverseDiff
4+
5+ using DiffRules
6+ using ForwardDiff
7+ using LogExpFunctions
8+ using NaNMath
9+ using SpecialFunctions
10+
11+ using Test
412
513include (joinpath (dirname (@__FILE__ ), " ../utils.jl" ))
614
715x, a, b = rand (3 )
816tp = InstructionTape ()
917int_range = 1 : 10
1018
11- function test_forward (f, x, tp:: InstructionTape , is_domain_err_func :: Bool )
19+ function test_forward (f, x, tp:: InstructionTape , fsym :: Symbol )
1220 xt = ReverseDiff. TrackedReal (x, zero (x), tp)
1321 y = f (x)
1422
@@ -23,7 +31,7 @@ function test_forward(f, x, tp::InstructionTape, is_domain_err_func::Bool)
2331 @test deriv (xt) == ForwardDiff. derivative (f, x)
2432
2533 # forward
26- x2 = is_domain_err_func ? rand () + 1 : rand ()
34+ x2 = modify_input (fsym, rand () )
2735 ReverseDiff. value! (xt, x2)
2836 ReverseDiff. forward_pass! (tp)
2937 @test value (yt) == f (x2)
@@ -133,15 +141,15 @@ function test_skip(f, a, b, tp)
133141 @test isempty (tp)
134142end
135143
136- DOMAIN_ERR_FUNCS = (:asec , :acsc , :asecd , :acscd , :acoth , :acosh )
137-
138- for (M, f, arity) in DiffRules. diffrules ()
144+ for (M, f, arity) in DiffRules. diffrules (; filter_modules= nothing )
145+ # ensure that function is defined
146+ if ! (isdefined (@__MODULE__ , M) && isdefined (getfield (@__MODULE__ , M), f))
147+ error (" $M .$f is not available" )
148+ end
139149 f === :rem2pi && continue
140150 if arity == 1
141151 test_println (" forward-mode unary scalar functions" , string (M, " ." , f))
142- is_domain_err_func = in (f, DOMAIN_ERR_FUNCS)
143- n = is_domain_err_func ? x + 1 : x
144- test_forward (eval (:($ M.$ f)), n, tp, is_domain_err_func)
152+ test_forward (eval (:($ M.$ f)), modify_input (f, x), tp, f)
145153 elseif arity == 2
146154 in (f, SKIPPED_BINARY_SCALAR_TESTS) && continue
147155 test_println (" forward-mode binary scalar functions" , f)
@@ -153,7 +161,7 @@ INT_ONLY_FUNCS = (:iseven, :isodd)
153161
154162for f in ReverseDiff. SKIPPED_UNARY_SCALAR_FUNCS
155163 test_println (" SKIPPED_UNARY_SCALAR_FUNCS" , f)
156- n = in (f, DOMAIN_ERR_FUNCS) ? x + 1 : x
164+ n = modify_input (f, x)
157165 n = in (f, INT_ONLY_FUNCS) ? ceil (Int, n) : n
158166 test_skip (eval (f), n, tp)
159167end
0 commit comments