Skip to content

Commit 61947d9

Browse files
simonbyrnegiordano
andauthored
combine Consts and API modules (#650)
* combine Consts and API modules * @t-bltg's fixes * Builds docs only for functions in `API` module Co-authored-by: Mosè Giordano <[email protected]>
1 parent 5a9ed4e commit 61947d9

32 files changed

+319
-322
lines changed

docs/src/reference/advanced.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,5 @@ MPI.set_errorhandler!
4545
## Miscellaneous
4646

4747
```@docs
48-
MPI.Consts.@const_ref
48+
MPI.API.@const_ref
4949
```

docs/src/reference/api.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
```@autodocs
44
Modules = [MPI.API]
5+
Order = [:function]
56
```

gen/src/MPIgenerator.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ module MPIgenerator
8080
end
8181
write(src, join(lines, "\n"))
8282

83-
dst = normpath(@__DIR__, "..", "..", "src", "auto_generated_api.jl")
83+
dst = normpath(@__DIR__, "..", "..", "src", "api", "generated_api.jl")
8484
mv(src, dst; force=true) # move the generated file to src
8585
rm(out) # cleanup
8686

src/MPI.jl

+18-45
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,31 @@ function deserialize(x)
1717
Serialization.deserialize(s)
1818
end
1919

20-
primitive type SentinelPtr Sys.WORD_SIZE
21-
end
22-
23-
primitive type MPIPtr Sys.WORD_SIZE
24-
end
25-
@assert sizeof(MPIPtr) == sizeof(Ptr{Cvoid})
26-
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
27-
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)
28-
2920

3021
function _doc_external(fname)
3122
"""
3223
- `$fname` man page: [OpenMPI](https://www.open-mpi.org/doc/current/man3/$fname.3.php), [MPICH](https://www.mpich.org/static/docs/latest/www3/$fname.html)
3324
"""
3425
end
3526

27+
"""
28+
MPIError
3629
37-
import MPIPreferences
38-
39-
if MPIPreferences.binary == "MPICH_jll"
40-
import MPICH_jll: libmpi, libmpi_handle, mpiexec
41-
const libmpiconstants = nothing
42-
elseif MPIPreferences.binary == "OpenMPI_jll"
43-
import OpenMPI_jll: libmpi, libmpi_handle, mpiexec
44-
const libmpiconstants = nothing
45-
elseif MPIPreferences.binary == "MicrosoftMPI_jll"
46-
import MicrosoftMPI_jll: libmpi, libmpi_handle, mpiexec
47-
const libmpiconstants = nothing
48-
elseif MPIPreferences.binary == "MPItrampoline_jll"
49-
import MPItrampoline_jll: MPItrampoline_jll, libmpi, libmpi_handle, mpiexec
50-
const libmpiconstants = MPItrampoline_jll.libload_time_mpi_constants_path
51-
elseif MPIPreferences.binary == "system"
52-
import MPIPreferences.System: libmpi, libmpi_handle, mpiexec
53-
const libmpiconstants = nothing
54-
else
55-
error("Unknown MPI binary: $(MPIPreferences.binary)")
30+
Error thrown when an MPI function returns an error code. The `code` field contains the MPI error code.
31+
"""
32+
struct MPIError <: Exception
33+
code::Cint
5634
end
35+
function Base.show(io::IO, err::MPIError)
36+
print(io, "MPIError(", err.code, "): ", error_string(err))
37+
end
38+
5739

58-
include("consts/consts.jl")
59-
using .Consts
40+
41+
42+
include("api/api.jl")
43+
using .API
44+
const Consts = API
6045

6146
# These functions are run after reading the values of the constants above)
6247
const _mpi_load_time_hooks = Any[]
@@ -73,21 +58,9 @@ function run_load_time_hooks()
7358
nothing
7459
end
7560

61+
using MPIPreferences
7662
include("implementations.jl")
7763
include("error.jl")
78-
79-
module API
80-
import ..libmpi, ..libmpi_handle, ..MPIPtr
81-
import ..use_stdcall, ..MPIError, ..@mpicall, ..@mpichk
82-
using ..Consts
83-
84-
for name in filter(n -> startswith(string(n), "MPI_"), names(Consts; all = true))
85-
@eval $name = Consts.$name # signatures need types
86-
end
87-
88-
include("auto_generated_api.jl")
89-
end
90-
9164
include("info.jl")
9265
include("group.jl")
9366
include("comm.jl")
@@ -140,7 +113,7 @@ function __init__()
140113

141114
# Needs to be called after `dlopen`. Use `invokelatest` so that `cglobal`
142115
# calls don't trigger early `dlopen`-ing of the library.
143-
Base.invokelatest(Consts.init_consts)
116+
Base.invokelatest(API.init_consts)
144117

145118
# disable UCX memory cache, since it doesn't work correctly
146119
# https://github.com/openucx/ucx/issues/5061
@@ -157,7 +130,7 @@ function __init__()
157130
end
158131

159132
if MPIPreferences.binary == "MPItrampoline_jll" && !haskey(ENV, "MPITRAMPOLINE_MPIEXEC")
160-
ENV["MPITRAMPOLINE_MPIEXEC"] = MPItrampoline_jll.mpich_mpiexec_path
133+
ENV["MPITRAMPOLINE_MPIEXEC"] = API.MPItrampoline_jll.mpich_mpiexec_path
161134
end
162135

163136
run_load_time_hooks()

src/api/api.jl

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
module API
2+
3+
export MPI_Aint, MPI_Count, MPI_Offset, MPI_Status,
4+
MPI_Comm, MPI_Datatype, MPI_Errhandler, MPI_File, MPI_Group,
5+
MPI_Info, MPI_Message, MPI_Op, MPI_Request, MPI_Win,
6+
libmpi, mpiexec, @mpichk, @mpicall, MPIPtr, SentinelPtr, FeatureLevelError
7+
8+
import MPIPreferences
9+
using Libdl
10+
11+
if MPIPreferences.binary == "MPICH_jll"
12+
import MPICH_jll: libmpi, libmpi_handle, mpiexec
13+
const libmpiconstants = nothing
14+
elseif MPIPreferences.binary == "OpenMPI_jll"
15+
import OpenMPI_jll: libmpi, libmpi_handle, mpiexec
16+
const libmpiconstants = nothing
17+
elseif MPIPreferences.binary == "MicrosoftMPI_jll"
18+
import MicrosoftMPI_jll: libmpi, libmpi_handle, mpiexec
19+
const libmpiconstants = nothing
20+
elseif MPIPreferences.binary == "MPItrampoline_jll"
21+
import MPItrampoline_jll: MPItrampoline_jll, libmpi, libmpi_handle, mpiexec
22+
const libmpiconstants = MPItrampoline_jll.libload_time_mpi_constants_path
23+
elseif MPIPreferences.binary == "system"
24+
import MPIPreferences.System: libmpi, libmpi_handle, mpiexec
25+
const libmpiconstants = nothing
26+
else
27+
error("Unknown MPI binary: $(MPIPreferences.binary)")
28+
end
29+
30+
import ..MPIError
31+
const initexprs = Any[]
32+
33+
"""
34+
@const_ref name T expr
35+
36+
Defines an constant binding
37+
```julia
38+
const name = Ref{T}()
39+
```
40+
and adds a hook to execute
41+
```julia
42+
name[] = expr
43+
```
44+
at module initialization time.
45+
"""
46+
macro const_ref(name, T, expr)
47+
push!(initexprs, :($name[] = $expr))
48+
:(const $(esc(name)) = Ref{$T}())
49+
end
50+
51+
@static if MPIPreferences.abi == "MPICH"
52+
include("mpich.jl")
53+
elseif MPIPreferences.abi == "OpenMPI"
54+
include("openmpi.jl")
55+
elseif MPIPreferences.abi == "MicrosoftMPI"
56+
include("microsoftmpi.jl")
57+
elseif MPIPreferences.abi == "MPItrampoline"
58+
include("mpitrampoline.jl")
59+
elseif MPIPreferences.abi == "HPE MPT"
60+
include("mpt.jl")
61+
else
62+
error("Unknown MPI ABI $(MPIPreferences.abi)")
63+
end
64+
65+
primitive type SentinelPtr Sys.WORD_SIZE
66+
end
67+
68+
primitive type MPIPtr Sys.WORD_SIZE
69+
end
70+
@assert sizeof(MPIPtr) == sizeof(Ptr{Cvoid})
71+
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
72+
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)
73+
74+
75+
# Initialize the ref constants from the library.
76+
# This is not `API.__init__`, as it should be called _after_
77+
# `dlopen` to ensure the library is opened correctly.
78+
@eval function init_consts()
79+
$(Expr(:block, initexprs...))
80+
end
81+
82+
const use_stdcall = startswith(basename(libmpi), "msmpi")
83+
84+
macro mpicall(expr)
85+
@assert expr isa Expr && expr.head == :call && expr.args[1] == :ccall
86+
87+
# On unix systems we call the global symbols to allow for LD_PRELOAD interception
88+
# It can be emulated in Windows (via Libdl.dllist), but this is not fast.
89+
if Sys.isunix() && expr.args[2].head == :tuple &&
90+
(VERSION v"1.5-" || expr.args[2].args[1] :(:MPI_Get_library_version))
91+
expr.args[2] = expr.args[2].args[1]
92+
end
93+
94+
# Microsoft MPI uses stdcall calling convention
95+
# this only affects 32-bit Windows
96+
# unfortunately we need to use ccall to call Get_library_version
97+
# so check using library name instead
98+
if use_stdcall
99+
insert!(expr.args, 3, :stdcall)
100+
end
101+
return esc(expr)
102+
end
103+
104+
"""
105+
FeatureLevelError
106+
107+
Error thrown if a feature is not implemented in the current MPI backend.
108+
"""
109+
struct FeatureLevelError <: Exception
110+
function_name::Symbol
111+
min_version::VersionNumber # minimal MPI version required for this feature to be available
112+
end
113+
function Base.show(io::IO, err::FeatureLevelError)
114+
print(io, "FeatureLevelError($(err.function_name)): Minimum MPI version is $(err.min_version)")
115+
end
116+
117+
macro mpichk(expr, min_version=nothing)
118+
if !isnothing(min_version) && expr.args[2].head == :tuple
119+
fn = expr.args[2].args[1].value
120+
if isnothing(dlsym(libmpi_handle, fn; throw_error=false))
121+
return quote
122+
throw(FeatureLevelError($(QuoteNode(fn)), $min_version))
123+
end
124+
end
125+
end
126+
127+
expr = macroexpand(@__MODULE__, :(@mpicall($expr)))
128+
# MPI_SUCCESS is defined to be 0
129+
:((errcode = $(esc(expr))) == 0 || throw(MPIError(errcode)))
130+
end
131+
132+
133+
include("generated_api.jl")
134+
135+
# since this is called by invokelatest, it isn't automatically precompiled
136+
precompile(init_consts, ())
137+
138+
end
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

src/consts/mpt.jl src/api/mpt.jl

File renamed without changes.
File renamed without changes.

src/buffers.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ MPIPtr
5454

5555
struct InPlace
5656
end
57-
Base.cconvert(::Type{MPIPtr}, ::InPlace) = Consts.MPI_IN_PLACE[]
57+
Base.cconvert(::Type{MPIPtr}, ::InPlace) = API.MPI_IN_PLACE[]
5858

5959

6060
"""

src/collective.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ If only one buffer `sendrecvbuf` is used, then data is overwritten.
464464
$(_doc_external("MPI_Alltoall"))
465465
"""
466466
function Alltoall!(sendbuf::UBuffer, recvbuf::UBuffer, comm::Comm)
467-
if sendbuf.data !== Consts.MPI_IN_PLACE[] && sendbuf.nchunks !== nothing
467+
if sendbuf.data !== API.MPI_IN_PLACE[] && sendbuf.nchunks !== nothing
468468
@assert sendbuf.nchunks >= Comm_size(comm)
469469
end
470470
if recvbuf.nchunks !== nothing
@@ -521,7 +521,7 @@ Similar to [`Alltoall!`](@ref), except with different size chunks per process.
521521
$(_doc_external("MPI_Alltoallv"))
522522
"""
523523
function Alltoallv!(sendbuf::VBuffer, recvbuf::VBuffer, comm::Comm)
524-
if sendbuf.data !== Consts.MPI_IN_PLACE[]
524+
if sendbuf.data !== API.MPI_IN_PLACE[]
525525
@assert length(sendbuf.counts) >= Comm_size(comm)
526526
end
527527
@assert length(recvbuf.counts) >= Comm_size(comm)

src/comm.jl

+11-11
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,25 @@ Base.unsafe_convert(::Type{MPI_Comm}, comm::Comm) = comm.val
1212
Base.unsafe_convert(::Type{Ptr{MPI_Comm}}, comm::Comm) = convert(Ptr{MPI_Comm}, pointer_from_objref(comm))
1313

1414

15-
const COMM_NULL = Comm(Consts.MPI_COMM_NULL[])
16-
add_load_time_hook!(() -> COMM_NULL.val = Consts.MPI_COMM_NULL[])
15+
const COMM_NULL = Comm(API.MPI_COMM_NULL[])
16+
add_load_time_hook!(() -> COMM_NULL.val = API.MPI_COMM_NULL[])
1717

1818
"""
1919
MPI.COMM_WORLD
2020
2121
A communicator containing all processes with which the local rank can communicate at
2222
initialization. In a typical "static-process" model, this will be all processes.
2323
"""
24-
const COMM_WORLD = Comm(Consts.MPI_COMM_WORLD[])
25-
add_load_time_hook!(() -> COMM_WORLD.val = Consts.MPI_COMM_WORLD[])
24+
const COMM_WORLD = Comm(API.MPI_COMM_WORLD[])
25+
add_load_time_hook!(() -> COMM_WORLD.val = API.MPI_COMM_WORLD[])
2626

2727
"""
2828
MPI.COMM_SELF
2929
3030
A communicator containing only the local process.
3131
"""
32-
const COMM_SELF = Comm(Consts.MPI_COMM_SELF[])
33-
add_load_time_hook!(() -> COMM_SELF.val = Consts.MPI_COMM_SELF[])
32+
const COMM_SELF = Comm(API.MPI_COMM_SELF[])
33+
add_load_time_hook!(() -> COMM_SELF.val = API.MPI_COMM_SELF[])
3434

3535
Comm() = Comm(COMM_NULL.val)
3636

@@ -173,7 +173,7 @@ $(_doc_external("MPI_Comm_split"))
173173
"""
174174
function Comm_split(comm::Comm, color::Union{Integer, Nothing}, key::Integer)
175175
if isnothing(color)
176-
color = Consts.MPI_UNDEFINED[]
176+
color = API.MPI_UNDEFINED[]
177177
end
178178
newcomm = Comm()
179179
API.MPI_Comm_split(comm, color, key, newcomm)
@@ -185,7 +185,7 @@ mutable struct SplitType
185185
val::Cint
186186
end
187187
const COMM_TYPE_SHARED = SplitType(-1)
188-
add_load_time_hook!(() -> COMM_TYPE_SHARED.val = Consts.MPI_COMM_TYPE_SHARED[])
188+
add_load_time_hook!(() -> COMM_TYPE_SHARED.val = API.MPI_COMM_TYPE_SHARED[])
189189

190190

191191
"""
@@ -205,7 +205,7 @@ $(_doc_external("MPI_Comm_split_type"))
205205
"""
206206
function Comm_split_type(comm::Comm, split_type, key::Integer; kwargs...)
207207
if isnothing(split_type)
208-
split_type = Consts.MPI_UNDEFINED[]
208+
split_type = API.MPI_UNDEFINED[]
209209
elseif split_type isa SplitType
210210
split_type = split_type.val
211211
end
@@ -276,7 +276,7 @@ The total number of available slots, or `nothing` if it is not defined. This is
276276
This is typically dependent on the MPI implementation: for MPICH-based implementations, this is specified by the `-usize` argument. OpenMPI defines a default value based on the number of processes available.
277277
"""
278278
function universe_size()
279-
ptr = unsafe_get_attr(COMM_WORLD, Consts.MPI_UNIVERSE_SIZE[])
279+
ptr = unsafe_get_attr(COMM_WORLD, API.MPI_UNIVERSE_SIZE[])
280280
isnothing(ptr) && return nothing
281281
return Int(unsafe_load(Ptr{Cint}(ptr)))
282282
end
@@ -287,7 +287,7 @@ end
287287
The maximum value tag value for point-to-point operations.
288288
"""
289289
function tag_ub()
290-
ptr = something(unsafe_get_attr(COMM_WORLD, Consts.MPI_TAG_UB[]))
290+
ptr = something(unsafe_get_attr(COMM_WORLD, API.MPI_TAG_UB[]))
291291
return Int(unsafe_load(Ptr{Cint}(ptr)))
292292
end
293293

0 commit comments

Comments
 (0)