Skip to content

Commit 7545358

Browse files
committed
[Utilities] add distance_to_set for more sets
1 parent c2cef34 commit 7545358

2 files changed

Lines changed: 89 additions & 4 deletions

File tree

src/Utilities/distance_to_set.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ function distance_to_set(
553553
MOI.PositiveSemidefiniteConeSquare,
554554
MOI.PositiveSemidefiniteConeTriangle,
555555
},
556-
) where {T<:Real}
556+
) where {T<:Union{Real,Complex}}
557557
_check_dimension(x, set)
558558
# We should return the norm of `A` defined by:
559559
# ```julia
@@ -565,8 +565,13 @@ function distance_to_set(
565565
# The norm should correspond to `MOI.Utilities.set_dot` so it's the
566566
# Frobenius norm, which is the Euclidean norm of the vector of eigenvalues.
567567
eigvals = LinearAlgebra.eigvals(_reshape(x, set))
568-
eigvals .= min.(zero(T), eigvals)
569-
return LinearAlgebra.norm(eigvals, 2)
568+
return LinearAlgebra.norm(_drop_positives.(eigvals), 2)
569+
end
570+
571+
_drop_positives(x::T) where {T<:Real} = min(zero(T), x)
572+
573+
function _drop_positives(x::T)::T where {T<:Complex}
574+
return ifelse(isreal(x), _drop_positives(real(x)), x)
570575
end
571576

572577
"""
@@ -702,3 +707,44 @@ function distance_to_set(
702707
push!(eigvals_neg, max(x[1] - x[2] * sum(log.(eigvals_pos)), zero(T)))
703708
return LinearAlgebra.norm(eigvals_neg, 2)
704709
end
710+
711+
"""
712+
distance_to_set(
713+
::ProjectionUpperBoundDistance,
714+
x::AbstractVector,
715+
set::MOI.Scaled{S},
716+
)
717+
718+
This is the distance in the un-scaled space.
719+
"""
720+
function distance_to_set(
721+
dist::ProjectionUpperBoundDistance,
722+
x::AbstractVector{T},
723+
set::MOI.Scaled{S},
724+
) where {T,S<:MOI.AbstractVectorSet}
725+
_check_dimension(x, set)
726+
scale = MOI.Utilities.SetDotScalingVector{T}(set.set)
727+
return distance_to_set(dist, x ./ scale, set.set)
728+
end
729+
730+
function distance_to_set(
731+
dist::ProjectionUpperBoundDistance,
732+
x::AbstractVector{T},
733+
set::MOI.HermitianPositiveSemidefiniteConeTriangle,
734+
) where {T<:Real}
735+
_check_dimension(x, set)
736+
output_set = MOI.PositiveSemidefiniteConeTriangle(set.side_dimension)
737+
y = zeros(Complex{T}, MOI.dimension(output_set))
738+
real_offset, imag_offset = 0, length(y)
739+
for col in 1:set.side_dimension
740+
for row in 1:col
741+
real_offset += 1
742+
y[real_offset] = x[real_offset]
743+
if row != col
744+
imag_offset += 1
745+
y[real_offset] += x[imag_offset] * im
746+
end
747+
end
748+
end
749+
return distance_to_set(dist, y, output_set)
750+
end

test/Utilities/distance_to_set.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ function test_positivesemidefiniteconesquare()
320320
[1.0, 0.0, 0.0, 1.0] => 0.0,
321321
[1.0, -1.0, -1.0, 1.0] => 0.0,
322322
[1.0, -2.0, -2.0, 1.0] => 1.0,
323-
[1.0, 1.1, 1.1, -2.3] => 2.633053201505194;
323+
[1.0, 1.1, 1.1, -2.3] => 2.633053201505194,
324+
[1.0, -2.0, -2.0, 1.0] => 1.0,;
324325
mismatch = [1.0],
325326
)
326327
return
@@ -476,6 +477,44 @@ function test_LogDetConeSquare()
476477
return
477478
end
478479

480+
function test_Scaled()
481+
_test_set(
482+
MOI.Scaled(MOI.PositiveSemidefiniteConeTriangle(2)),
483+
[1.0, 0.0, 1.0] => 0.0,
484+
[1.0, -1.0, 1.0] => 0.0,
485+
[1.0, -2.0 * sqrt(2), 1.0] => 1.0,
486+
[1.0, 1.1 * sqrt(2), -2.3] => 2.633053201505194;
487+
mismatch = [1.0],
488+
)
489+
return
490+
end
491+
492+
function test_PositiveSemidefiniteConeTriangle_Complex()
493+
_test_set(
494+
MOI.PositiveSemidefiniteConeTriangle(2),
495+
ComplexF64[1.0, 0.0, 1.0] => 0.0,
496+
ComplexF64[1.0, -1.0, 1.0] => 0.0,
497+
ComplexF64[1.0, -2.0, 1.0] => 1.0,
498+
ComplexF64[1.0, 1.1, -2.3] => 2.633053201505194,
499+
ComplexF64[1.0, 1-im, 1.0] => 2.449489742783177;
500+
mismatch = [1.0],
501+
)
502+
return
503+
end
504+
505+
function test_HermitianPositiveSemidefiniteConeTriangle()
506+
_test_set(
507+
MOI.HermitianPositiveSemidefiniteConeTriangle(2),
508+
[1.0, 0.0, 1.0, 0.0] => 0.0,
509+
[1.0, -1.0, 1.0, 0.0] => 0.0,
510+
[1.0, -2.0, 1.0, 0.0] => 1.0,
511+
[1.0, 1.1, -2.3, 0.0] => 2.633053201505194,
512+
[1.0, 1.0, 1.0, -1.0] => 2.449489742783177;
513+
mismatch = [1.0],
514+
)
515+
return
516+
end
517+
479518
end
480519

481520
TestFeasibilityChecker.runtests()

0 commit comments

Comments
 (0)