@@ -381,7 +381,7 @@ def mmread(source, parallelism=None, long_type=False):
381
381
382
382
383
383
def mmwrite (target , a , comment = None , field = None , precision = None , symmetry = "AUTO" ,
384
- parallelism = None , find_symmetry = False ):
384
+ parallelism = None , find_symmetry = False ):
385
385
"""
386
386
Write a matrix to a MatrixMarket file or file-like object.
387
387
@@ -428,20 +428,51 @@ def mmwrite(target, a, comment=None, field=None, precision=None, symmetry="AUTO"
428
428
_core .write_body_array (cursor , a )
429
429
return
430
430
431
- if scipy .sparse .isspmatrix (a ):
431
+ # handle both scipy.sparse.*_matrix and scipy.sparse.*_array
432
+ # Both have the same interface as far as this method is concerned, so let duck typing do its thing.
433
+ # Support for these types varies between scipy versions, so attempt to support all possibilities.
434
+ is_sparse = False
435
+ is_compressed = False
436
+ coo_type = None
437
+ csr_types = []
438
+
439
+ # check for *_matrix
440
+ try :
441
+ if scipy .sparse .isspmatrix (a ):
442
+ is_sparse = True
443
+ from scipy .sparse import coo_matrix
444
+ coo_type = coo_matrix
445
+ # CSC and CSR have specialized writers.
446
+ is_compressed = (isinstance (a , scipy .sparse .csc_matrix ) or isinstance (a , scipy .sparse .csr_matrix ))
447
+ csr_types .append (scipy .sparse .csr_matrix )
448
+ except ImportError :
449
+ pass
450
+
451
+ # check for *_array
452
+ try :
453
+ if scipy .sparse .issparse (a ):
454
+ is_sparse = True
455
+ from scipy .sparse import coo_array
456
+ coo_type = coo_array
457
+ # CSC and CSR have specialized writers. The type may already be a cs*_matrix.
458
+ is_compressed = is_compressed or \
459
+ (isinstance (a , scipy .sparse .csc_array ) or isinstance (a , scipy .sparse .csr_array ))
460
+ csr_types .append (scipy .sparse .csr_array )
461
+ except ImportError :
462
+ pass
463
+
464
+ if is_sparse :
432
465
# Write sparse scipy matrices
433
466
if symmetry is not None and symmetry != "general" :
434
467
# A symmetric matrix only specifies the elements below the diagonal.
435
468
# Ensure that the matrix satisfies this requirement.
436
- from scipy . sparse import coo_matrix
469
+
437
470
a = a .tocoo ()
438
471
lower_triangle_mask = a .row >= a .col
439
- a = coo_matrix ((a .data [lower_triangle_mask ],
440
- (a .row [lower_triangle_mask ],
441
- a .col [lower_triangle_mask ])), shape = a .shape )
442
-
443
- # CSC and CSR have specialized writers.
444
- is_compressed = (isinstance (a , scipy .sparse .csc_matrix ) or isinstance (a , scipy .sparse .csr_matrix ))
472
+ a = coo_type ((a .data [lower_triangle_mask ],
473
+ (a .row [lower_triangle_mask ],
474
+ a .col [lower_triangle_mask ])), shape = a .shape )
475
+ is_compressed = False
445
476
446
477
if not is_compressed :
447
478
# convert everything except CSC/CSR to coo
@@ -451,7 +482,7 @@ def mmwrite(target, a, comment=None, field=None, precision=None, symmetry="AUTO"
451
482
452
483
if is_compressed :
453
484
# CSC and CSR can be written directly
454
- is_csr = isinstance (a , scipy . sparse . csr_matrix )
485
+ is_csr = any ([ isinstance (a , t ) for t in csr_types ] )
455
486
_core .write_body_csc (cursor , a .shape , a .indptr , a .indices , data , is_csr )
456
487
else :
457
488
_core .write_body_coo (cursor , a .shape , a .row , a .col , data )
0 commit comments