Skip to content

Commit b0818b9

Browse files
authored
Add Aqua tests (#775)
* Add Aqua tests * Fix logpdf(::NamedDist) method ambiguity * Fix SimpleVarInfo method ambiguity * Fix VarInfo method ambiguity * Add InteractiveUtils compat entry See: https://discourse.julialang.org/t/psa-compat-requirements-in-the-general-registry-are-changing/104958 * Add Random.AbstractRNG type annotation * Remove unneeded getsym method * Fix (newly introduced 😅) ConditionContext method ambiguity * Fix unwrap_right_left_vns method ambiguity * KernelAbstractions is a weakdep not a dep * Fix StaticTransformation / ThreadSafeVarInfo link/invlink ambiguity * Fix more RNGs * Don't run Aqua tests on CI min versions * Fix ternary in GitHub Actions expression
1 parent 4bc43a4 commit b0818b9

14 files changed

+61
-13
lines changed

.github/workflows/CI.yml

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ jobs:
7070
env:
7171
GROUP: ${{ matrix.test_group }}
7272
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}
73+
# Only run Aqua tests on latest version
74+
AQUA: ${{ matrix.runner.version == '1' && 'true' || 'false' }}
7375

7476
- uses: julia-actions/julia-processcoverage@v1
7577

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1616
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1717
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1818
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
19-
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
2019
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2120
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2221
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -30,6 +29,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3029
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3130
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3231
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
32+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3333
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3434
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3535

@@ -56,6 +56,7 @@ Distributions = "0.25"
5656
DocStringExtensions = "0.9"
5757
EnzymeCore = "0.6 - 0.8"
5858
ForwardDiff = "0.10.12"
59+
InteractiveUtils = "1"
5960
JET = "0.9"
6061
KernelAbstractions = "0.9.33"
6162
LinearAlgebra = "1.6"

benchmarks/src/DynamicPPLBenchmarks.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using LogDensityProblems: LogDensityProblems
99
using ForwardDiff: ForwardDiff
1010
using Mooncake: Mooncake
1111
using ReverseDiff: ReverseDiff
12+
using StableRNGs: StableRNG
1213

1314
include("./Models.jl")
1415
using .Models: Models
@@ -61,18 +62,20 @@ The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversedi
6162
`islinked` determines whether to link the VarInfo for evaluation.
6263
"""
6364
function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
65+
rng = StableRNG(23)
66+
6467
suite = BenchmarkGroup()
6568

6669
vi = if varinfo_choice == :untyped
6770
vi = VarInfo()
68-
model(vi)
71+
model(rng, vi)
6972
vi
7073
elseif varinfo_choice == :typed
71-
VarInfo(model)
74+
VarInfo(rng, model)
7275
elseif varinfo_choice == :simple_namedtuple
73-
SimpleVarInfo{Float64}(model())
76+
SimpleVarInfo{Float64}(model(rng))
7477
elseif varinfo_choice == :simple_dict
75-
retvals = model()
78+
retvals = model(rng)
7679
vns = [VarName{k}() for k in keys(retvals)]
7780
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
7881
else

src/compiler.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ x[1][3]
250250
```
251251
"""
252252
unwrap_right_left_vns(right, left, vns) = right, left, vns
253-
function unwrap_right_left_vns(right::NamedDist, left, vns)
253+
function unwrap_right_left_vns(right::NamedDist, left::AbstractArray, ::VarName)
254+
return unwrap_right_left_vns(right.dist, left, right.name)
255+
end
256+
function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName)
254257
return unwrap_right_left_vns(right.dist, left, right.name)
255258
end
256259
function unwrap_right_left_vns(

src/context_implementations.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ function tilde_observe!!(context, right, left, vi)
195195
return left, acclogp_observe!!(context, vi, logp)
196196
end
197197

198-
function assume(rng, spl::Sampler, dist)
198+
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
199199
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
200200
end
201201

src/contexts.jl

+2
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ function ConditionContext(values::Union{NamedTuple,AbstractDict})
335335
end
336336
# Optimisation when there are no values to condition on
337337
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
338+
# Same as above, and avoids method ambiguity with below
339+
ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context
338340
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
339341
# values inside the child context, thus giving precedence to the outermost
340342
# `ConditionContext`.

src/distribution_wrappers.jl

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ Base.length(dist::NamedDist) = Base.length(dist.dist)
1717
Base.size(dist::NamedDist) = Base.size(dist.dist)
1818

1919
Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
20+
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real,0})
21+
# extract the singleton value from 0-dimensional array
22+
return Distributions.logpdf(dist.dist, first(x))
23+
end
2024
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
2125
return Distributions.logpdf(dist.dist, x)
2226
end

src/simple_varinfo.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,14 @@ function SimpleVarInfo(; kwargs...)
232232
end
233233

234234
# Constructor from `Model`.
235-
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
236-
function SimpleVarInfo{T}(model::Model, args...) where {T<:Real}
235+
function SimpleVarInfo(
236+
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
237+
)
238+
return SimpleVarInfo{Float64}(model, args...)
239+
end
240+
function SimpleVarInfo{T}(
241+
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
242+
) where {T<:Real}
237243
return last(evaluate!!(model, SimpleVarInfo{T}(), args...))
238244
end
239245

src/threadsafe.jl

+13
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,19 @@ function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
115115
return invlink!!(t, deepcopy(vi), model)
116116
end
117117

118+
# These two StaticTransformation methods needed to resolve ambiguities
119+
function link!!(
120+
t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
121+
)
122+
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model)
123+
end
124+
125+
function invlink!!(
126+
t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
127+
)
128+
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model)
129+
end
130+
118131
function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
119132
# Defer to the wrapped `AbstractVarInfo` object.
120133
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the

src/varinfo.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,11 @@ function VarInfo(
200200
)
201201
return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata)
202202
end
203-
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
203+
function VarInfo(
204+
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
205+
)
206+
return VarInfo(Random.default_rng(), model, args...)
207+
end
204208

205209
"""
206210
vector_length(varinfo::VarInfo)

src/varname.jl

-3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,3 @@ Possibly existing indices of `varname` are neglected.
4141
) where {s,missings,_F,_a,_T}
4242
return s in missings
4343
end
44-
45-
# HACK: Type-piracy. Is this really the way to go?
46-
AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym

test/Aqua.jl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module AquaTests
2+
3+
using Aqua: Aqua
4+
using DynamicPPL
5+
6+
Aqua.test_all(DynamicPPL)
7+
8+
end

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
44
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
55
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
6+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
67
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
78
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
89
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

test/runtests.jl

+4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ using OrderedCollections: OrderedSet
3535
using DynamicPPL: getargs_dottilde, getargs_tilde
3636

3737
const GROUP = get(ENV, "GROUP", "All")
38+
const AQUA = get(ENV, "AQUA", "true") == "true"
3839
Random.seed!(100)
3940

4041
include("test_util.jl")
@@ -44,6 +45,9 @@ include("test_util.jl")
4445
# groups are chosen to make both groups take roughly the same amount of
4546
# time, but beyond that there is no particular reason for the split.
4647
if GROUP == "All" || GROUP == "Group1"
48+
if AQUA
49+
include("Aqua.jl")
50+
end
4751
include("utils.jl")
4852
include("compiler.jl")
4953
include("varnamedvector.jl")

0 commit comments

Comments
 (0)