Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUSOLVER] Support symmetric factorization without pivoting #2640

Merged
maleadt merged 3 commits into JuliaGPU:master on Feb 4, 2025

Conversation

amontoison
Copy link
Member

No description provided.

Copy link
Contributor

github-actions bot commented Feb 3, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/cusolver/dense.jl b/lib/cusolver/dense.jl
index 12e853c44..bd92e180d 100644
--- a/lib/cusolver/dense.jl
+++ b/lib/cusolver/dense.jl
@@ -206,7 +206,7 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
             A, ipiv, info
         end
 
-        function sytrf!(uplo::Char, A::StridedCuMatrix{$elty}; pivoting::Bool=true)
+        function sytrf!(uplo::Char, A::StridedCuMatrix{$elty}; pivoting::Bool = true)
             n = checksquare(A)
             if pivoting
                 ipiv = CuArray{Cint}(undef, n)
@@ -225,8 +225,10 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
                 end
 
                 with_workspace(dh.workspace_gpu, bufferSize) do buffer
-                    $fname(dh, uplo, n, A, lda, CU_NULL,
-                           buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
+                    $fname(
+                        dh, uplo, n, A, lda, CU_NULL,
+                        buffer, sizeof(buffer) ÷ sizeof($elty), dh.info
+                    )
                 end
 
                 info = @allowscalar dh.info[1]
diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl
index 7f2f1569a..07e9a4305 100644
--- a/lib/cusolver/dense_generic.jl
+++ b/lib/cusolver/dense_generic.jl
@@ -160,20 +160,26 @@ function sytrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}) where
     function bufferSize()
         out_cpu = Ref{Csize_t}(0)
         out_gpu = Ref{Csize_t}(0)
-        cusolverDnXsytrs_bufferSize(dh, uplo, n, nrhs, T, A,
-                                    lda, CU_NULL, T, B, ldb, out_gpu, out_cpu)
-        out_gpu[], out_cpu[]
-    end
-    with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
-                    bufferSize()...) do buffer_gpu, buffer_cpu
-        cusolverDnXsytrs(dh, uplo, n, nrhs, T, A, lda, CU_NULL,
-                         T, B, ldb, buffer_gpu, sizeof(buffer_gpu),
-                         buffer_cpu, sizeof(buffer_cpu), dh.info)
+        cusolverDnXsytrs_bufferSize(
+            dh, uplo, n, nrhs, T, A,
+            lda, CU_NULL, T, B, ldb, out_gpu, out_cpu
+        )
+        return out_gpu[], out_cpu[]
+    end
+    with_workspaces(
+        dh.workspace_gpu, dh.workspace_cpu,
+        bufferSize()...
+    ) do buffer_gpu, buffer_cpu
+        cusolverDnXsytrs(
+            dh, uplo, n, nrhs, T, A, lda, CU_NULL,
+            T, B, ldb, buffer_gpu, sizeof(buffer_gpu),
+            buffer_cpu, sizeof(buffer_cpu), dh.info
+        )
     end
 
     flag = @allowscalar dh.info[1]
     chkargsok(flag |> BlasInt)
-    B
+    return B
 end
 
 # Xtrtri
diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl
index 003d55e65..a5c27e6a3 100644
--- a/test/libraries/cusolver/dense_generic.jl
+++ b/test/libraries/cusolver/dense_generic.jl
@@ -94,8 +94,8 @@ p = 5
     @testset "sytrs!" begin
         @testset "uplo = $uplo" for uplo in ('L', 'U')
             @testset "pivoting = $pivoting" for pivoting in (false, true)
-                A = rand(elty,n,n)
-                B = rand(elty,n,p)
+                A = rand(elty, n, n)
+                B = rand(elty, n, p)
                 A = A + transpose(A)
                 d_A = CuMatrix(A)
                 d_B = CuMatrix(B)

@maleadt added the labels "enhancement" (New feature or request) and "cuda libraries" (Stuff about CUDA library wrappers) on Feb 4, 2025
@maleadt maleadt merged commit f3b3f8b into JuliaGPU:master Feb 4, 2025
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
"cuda libraries" (Stuff about CUDA library wrappers), "enhancement" (New feature or request)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants