@@ -305,7 +305,7 @@ def diamonddist(a, b, mx_basis='pp', return_x=False):
305
305
mx_basis = _bt .create_basis_for_matrix (a , mx_basis )
306
306
307
307
# currently cvxpy is only needed for this function, so don't import until here
308
- import cvxpy as _cvxpy
308
+ import cvxpy as _cp
309
309
310
310
# _jam code below assumes *un-normalized* Jamiol-isomorphism.
311
311
# It will convert a & b to a "single-block" basis representation
@@ -322,19 +322,29 @@ def diamonddist(a, b, mx_basis='pp', return_x=False):
322
322
J = JBstd - JAstd
323
323
prob , vars = _diamond_norm_model (dim , smallDim , J )
324
324
325
- try :
326
- prob .solve (solver = 'Clarabel' )
327
- except _cvxpy .error .SolverError as e :
328
- _warnings .warn ("CVXPY failed: %s - diamonddist returning -2!" % str (e ))
329
- return (- 2 , _np .zeros ((dim , dim ))) if return_x else - 2
330
- except :
331
- _warnings .warn ("CVXOPT failed (unknown err) - diamonddist returning -2!" )
332
- return (- 2 , _np .zeros ((dim , dim ))) if return_x else - 2
325
+ objective_val = - 2
326
+ varvals = [_np .zeros_like (J ), None , None ]
327
+ sdp_solvers = ['MOSEK' , 'CLARABEL' , 'CVXOPT' ]
328
+ for i , solver in enumerate (sdp_solvers ):
329
+ try :
330
+ prob .solve (solver = solver )
331
+ objective_val = prob .value
332
+ varvals = [v .value for v in vars ]
333
+ break
334
+ except (AssertionError , _cp .SolverError ) as e :
335
+ if solver != 'MOSEK' :
336
+ msg = f"Received error { e } when trying to use solver={ solver } ."
337
+ if i + 1 == len (sdp_solvers ):
338
+ failure_msg = "Out of solvers. Returning -2 for diamonddist."
339
+ else :
340
+ failure_msg = f"Trying { sdp_solvers [i + 1 ]} next."
341
+ msg += f'\n { failure_msg } '
342
+ _warnings .warn (msg )
333
343
334
344
if return_x :
335
- return prob . value , vars [ 0 ]. value
345
+ return objective_val , varvals
336
346
else :
337
- return prob . value
347
+ return objective_val
338
348
339
349
340
350
def _diamond_norm_model (dim , smallDim , J ):
0 commit comments