Skip to content

Commit 995f7ab

Browse files
author
Corey Ostrove
committed
Update solver selection
Add in multiple solver options that are sequentially tried before fully failing. Also turn back on diamond distance tests on windows which were previously being skipped.
1 parent 6d98a2a commit 995f7ab

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

pygsti/tools/optools.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def diamonddist(a, b, mx_basis='pp', return_x=False):
305305
mx_basis = _bt.create_basis_for_matrix(a, mx_basis)
306306

307307
# currently cvxpy is only needed for this function, so don't import until here
308-
import cvxpy as _cvxpy
308+
import cvxpy as _cp
309309

310310
# _jam code below assumes *un-normalized* Jamiol-isomorphism.
311311
# It will convert a & b to a "single-block" basis representation
@@ -322,19 +322,29 @@ def diamonddist(a, b, mx_basis='pp', return_x=False):
322322
J = JBstd - JAstd
323323
prob, vars = _diamond_norm_model(dim, smallDim, J)
324324

325-
try:
326-
prob.solve(solver='Clarabel')
327-
except _cvxpy.error.SolverError as e:
328-
_warnings.warn("CVXPY failed: %s - diamonddist returning -2!" % str(e))
329-
return (-2, _np.zeros((dim, dim))) if return_x else -2
330-
except:
331-
_warnings.warn("CVXOPT failed (unknown err) - diamonddist returning -2!")
332-
return (-2, _np.zeros((dim, dim))) if return_x else -2
325+
objective_val = -2
326+
varvals = [_np.zeros_like(J), None, None]
327+
sdp_solvers = ['MOSEK', 'CLARABEL', 'CVXOPT']
328+
for i, solver in enumerate(sdp_solvers):
329+
try:
330+
prob.solve(solver=solver)
331+
objective_val = prob.value
332+
varvals = [v.value for v in vars]
333+
break
334+
except (AssertionError, _cp.SolverError) as e:
335+
if solver != 'MOSEK':
336+
msg = f"Received error {e} when trying to use solver={solver}."
337+
if i + 1 == len(sdp_solvers):
338+
failure_msg = "Out of solvers. Returning -2 for diamonddist."
339+
else:
340+
failure_msg = f"Trying {sdp_solvers[i+1]} next."
341+
msg += f'\n{failure_msg}'
342+
_warnings.warn(msg)
333343

334344
if return_x:
335-
return prob.value, vars[0].value
345+
return objective_val, varvals
336346
else:
337-
return prob.value
347+
return objective_val
338348

339349

340350
def _diamond_norm_model(dim, smallDim, J):

test/unit/tools/test_optools.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from pygsti.modelpacks.legacy import std2Q_XXYYII
1515
from ..util import BaseCase, needs_cvxpy
1616

17-
SKIP_DIAMONDIST_ON_WIN = True
18-
1917

2018
def fake_minimize(fn):
2119
"""Mock scipy.optimize.minimize in the underlying function call to reduce optimization overhead"""
@@ -387,7 +385,6 @@ def test_jtrace_distance(self):
387385

388386
@needs_cvxpy
389387
def test_diamond_distance(self):
390-
if SKIP_DIAMONDIST_ON_WIN and sys.platform.startswith('win'): return
391388
val = ot.diamonddist(self.A_TP, self.A_TP, mx_basis="pp")
392389
self.assertAlmostEqual(val, 0.0)
393390
val = ot.diamonddist(self.A_TP, self.B_unitary, mx_basis="pp")

0 commit comments

Comments
 (0)