Skip to content

Commit 0551195

Browse files
committed
Improve handling of scipy.sparse.*_array in mmwrite()
1 parent 4e4c7bd commit 0551195

File tree

1 file changed

+41
-10
lines changed

1 file changed

+41
-10
lines changed

python/src/fast_matrix_market/__init__.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def mmread(source, parallelism=None, long_type=False):
381381

382382

383383
def mmwrite(target, a, comment=None, field=None, precision=None, symmetry="AUTO",
384-
parallelism=None, find_symmetry=False):
384+
parallelism=None, find_symmetry=False):
385385
"""
386386
Write a matrix to a MatrixMarket file or file-like object.
387387
@@ -428,20 +428,51 @@ def mmwrite(target, a, comment=None, field=None, precision=None, symmetry="AUTO"
428428
_core.write_body_array(cursor, a)
429429
return
430430

431-
if scipy.sparse.isspmatrix(a):
431+
# handle both scipy.sparse.*_matrix and scipy.sparse.*_array
432+
# Both have the same interface as far as this method is concerned, so let duck typing do its thing.
433+
# Support for these types varies between scipy versions, so attempt to support all possibilities.
434+
is_sparse = False
435+
is_compressed = False
436+
coo_type = None
437+
csr_types = []
438+
439+
# check for *_matrix
440+
try:
441+
if scipy.sparse.isspmatrix(a):
442+
is_sparse = True
443+
from scipy.sparse import coo_matrix
444+
coo_type = coo_matrix
445+
# CSC and CSR have specialized writers.
446+
is_compressed = (isinstance(a, scipy.sparse.csc_matrix) or isinstance(a, scipy.sparse.csr_matrix))
447+
csr_types.append(scipy.sparse.csr_matrix)
448+
except ImportError:
449+
pass
450+
451+
# check for *_array
452+
try:
453+
if scipy.sparse.issparse(a):
454+
is_sparse = True
455+
from scipy.sparse import coo_array
456+
coo_type = coo_array
457+
# CSC and CSR have specialized writers. The type may already be a cs*_matrix.
458+
is_compressed = is_compressed or \
459+
(isinstance(a, scipy.sparse.csc_array) or isinstance(a, scipy.sparse.csr_array))
460+
csr_types.append(scipy.sparse.csr_array)
461+
except ImportError:
462+
pass
463+
464+
if is_sparse:
432465
# Write sparse scipy matrices
433466
if symmetry is not None and symmetry != "general":
434467
# A symmetric matrix only specifies the elements below the diagonal.
435468
# Ensure that the matrix satisfies this requirement.
436-
from scipy.sparse import coo_matrix
469+
437470
a = a.tocoo()
438471
lower_triangle_mask = a.row >= a.col
439-
a = coo_matrix((a.data[lower_triangle_mask],
440-
(a.row[lower_triangle_mask],
441-
a.col[lower_triangle_mask])), shape=a.shape)
442-
443-
# CSC and CSR have specialized writers.
444-
is_compressed = (isinstance(a, scipy.sparse.csc_matrix) or isinstance(a, scipy.sparse.csr_matrix))
472+
a = coo_type((a.data[lower_triangle_mask],
473+
(a.row[lower_triangle_mask],
474+
a.col[lower_triangle_mask])), shape=a.shape)
475+
is_compressed = False
445476

446477
if not is_compressed:
447478
# convert everything except CSC/CSR to coo
@@ -451,7 +482,7 @@ def mmwrite(target, a, comment=None, field=None, precision=None, symmetry="AUTO"
451482

452483
if is_compressed:
453484
# CSC and CSR can be written directly
454-
is_csr = isinstance(a, scipy.sparse.csr_matrix)
485+
is_csr = any([isinstance(a, t) for t in csr_types])
455486
_core.write_body_csc(cursor, a.shape, a.indptr, a.indices, data, is_csr)
456487
else:
457488
_core.write_body_coo(cursor, a.shape, a.row, a.col, data)

0 commit comments

Comments
 (0)