Skip to content

Commit 9c98251

Browse files
committed
resolving type ambiguity
1 parent 90dbe0c commit 9c98251

File tree

2 files changed

+65
-68
lines changed

2 files changed

+65
-68
lines changed

ext/ModelingToolkitSIExt.jl

Lines changed: 59 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,15 @@ function __mtk_to_si(
282282
end
283283
# -----------------------------------------------------------------------------
284284
"""
285-
function assess_local_identifiability(ode::ModelingToolkit.System; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
285+
function assess_local_identifiability(sys::ModelingToolkit.System; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
286286
287287
Input:
288-
- `ode` - the System object from ModelingToolkit
288+
- `ode` - the System object from ModelingToolkit (could represent an ODE or a discrete-time dynamical system)
289289
- `measured_quantities` - the measurable outputs of the model
290290
- `funcs_to_check` - functions of parameters for which to check identifiability
291+
- `known_ic` - functions of states (e.g., some of the states) for which initial conditions are assumed to be known (and generic)
291292
- `prob_threshold` - probability of correctness
292-
- `type` - identifiability type (`:SE` for single-experiment, `:ME` for multi-experiment)
293+
- `type` - identifiability type (`:SE` for single-experiment, `:ME` for multi-experiment). `:ME` not implemented for discrete-time systems
293294
- `loglevel` - the minimal level of log messages to display (`Logging.Info` by default)
294295
295296
Output:
@@ -299,34 +300,50 @@ Output:
299300
The function determines local identifiability of parameters in `funcs_to_check` or all possible parameters if `funcs_to_check` is empty
300301
301302
The result is correct with probability at least `prob_threshold`.
302-
303-
`type` can be either `:SE` (single-experiment identifiability) or `:ME` (multi-experiment identifiability).
304-
The return value is a tuple consisting of the array of bools and the number of experiments to be performed.
305303
"""
306304
function StructuralIdentifiability.assess_local_identifiability(
307-
ode::ModelingToolkit.System;
305+
sys::ModelingToolkit.System;
308306
measured_quantities = ModelingToolkit.Equation[],
309307
funcs_to_check = Array{}[],
308+
known_ic = [],
310309
prob_threshold::Float64 = 0.99,
311310
type = :SE,
312311
loglevel = Logging.Info,
313312
)
314313
restart_logging(loglevel = loglevel)
315314
with_logger(_si_logger[]) do
316-
return _assess_local_identifiability(
317-
ode,
318-
measured_quantities = measured_quantities,
319-
funcs_to_check = funcs_to_check,
320-
prob_threshold = prob_threshold,
321-
type = type,
322-
)
315+
if any(ModelingToolkit.hasshift, equations(sys))
316+
if type == :ME
317+
throw(
318+
"Only single-experiment identifiability is implemented in the discrete-time case",
319+
)
320+
else
321+
return _assess_local_identifiability_dds(
322+
sys,
323+
measured_quantities = measured_quantities,
324+
funcs_to_check = funcs_to_check,
325+
known_ic = known_ic,
326+
prob_threshold = prob_threshold,
327+
)
328+
end
329+
else
330+
return _assess_local_identifiability_ode(
331+
sys,
332+
measured_quantities = measured_quantities,
333+
funcs_to_check = funcs_to_check,
334+
known_ic = known_ic,
335+
prob_threshold = prob_threshold,
336+
type = type,
337+
)
338+
end
323339
end
324340
end
325341

326-
@timeit _to function _assess_local_identifiability(
342+
@timeit _to function _assess_local_identifiability_ode(
327343
ode::ModelingToolkit.System;
328344
measured_quantities = Array{ModelingToolkit.Equation}[],
329345
funcs_to_check = Array{}[],
346+
known_ic = [],
330347
prob_threshold::Float64 = 0.99,
331348
type = :SE,
332349
)
@@ -338,19 +355,39 @@ end
338355
end
339356

340357
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
358+
known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic]
341359

342360
if isequal(type, :SE)
343-
result = StructuralIdentifiability._assess_local_identifiability(
344-
ode,
345-
funcs_to_check = funcs_to_check_,
346-
prob_threshold = prob_threshold,
347-
type = type,
348-
)
361+
if isempty(known_ic)
362+
result = StructuralIdentifiability._assess_local_identifiability(
363+
ode,
364+
funcs_to_check = funcs_to_check_,
365+
prob_threshold = prob_threshold,
366+
type = type,
367+
)
368+
else
369+
result = StructuralIdentifiability._assess_local_identifiability_kic(
370+
ode,
371+
funcs_to_check = funcs_to_check_,
372+
prob_threshold = prob_threshold,
373+
known_ic = known_ic_,
374+
)
375+
end
349376
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
350377
out_dict =
351378
OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
379+
if length(known_ic) > 0
380+
@warn "Since known initial conditions were provided, identifiability of states (e.g., `x(t)`) is at t = 0 only !"
381+
t = SymbolicUtils.Sym{Real}(:t)
382+
out_dict = OrderedDict(substitute(k, Dict(t => 0)) => v for (k, v) in out_dict)
383+
end
352384
return out_dict
353385
elseif isequal(type, :ME)
386+
if !isempty(known_ic)
387+
throw(
388+
"Known initail conditions are not well-defined in the multi-experimental regime",
389+
)
390+
end
354391
result, bd = StructuralIdentifiability._assess_local_identifiability(
355392
ode,
356393
funcs_to_check = funcs_to_check_,
@@ -447,47 +484,7 @@ end
447484

448485
# ------------------------------------------------------------------------------
449486

450-
"""
451-
function assess_local_identifiability(
452-
dds::ModelingToolkit.System;
453-
measured_quantities=Array{ModelingToolkit.Equation}[],
454-
funcs_to_check=Array{}[],
455-
known_ic=Array{}[],
456-
prob_threshold::Float64=0.99)
457-
458-
Input:
459-
- `dds` - the System object from ModelingToolkit
460-
- `measured_quantities` - the measurable outputs of the model
461-
- `funcs_to_check` - functions of parameters for which to check identifiability (all parameters and states if not specified)
462-
- `known_ic` - functions (of states and parameter) whose initial conditions are assumed to be known
463-
- `prob_threshold` - probability of correctness
464-
465-
Output:
466-
- the result is an (ordered) dictionary from each function to to boolean;
467-
468-
The result is correct with probability at least `prob_threshold`.
469-
"""
470-
function StructuralIdentifiability.assess_local_identifiability(
471-
dds::ModelingToolkit.System;
472-
measured_quantities = Array{ModelingToolkit.Equation}[],
473-
funcs_to_check = Array{}[],
474-
known_ic = Array{}[],
475-
prob_threshold::Float64 = 0.99,
476-
loglevel = Logging.Info,
477-
)
478-
restart_logging(loglevel = loglevel)
479-
with_logger(_si_logger[]) do
480-
return _assess_local_identifiability(
481-
dds,
482-
measured_quantities = measured_quantities,
483-
funcs_to_check = funcs_to_check,
484-
known_ic = known_ic,
485-
prob_threshold = prob_threshold,
486-
)
487-
end
488-
end
489-
490-
function _assess_local_identifiability(
487+
function _assess_local_identifiability_dds(
491488
dds::ModelingToolkit.System;
492489
measured_quantities = Array{ModelingToolkit.Equation}[],
493490
funcs_to_check = Array{}[],

test/extensions/modelingtoolkit.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
568568
I(k) ~ I(k - 1) + β * S(k - 1) * I(k - 1) - α * I(k - 1),
569569
R(k) ~ R(k - 1) + α * I(k - 1),
570570
]
571-
@mtkbuild sir = System(eqs, t)
571+
@named sir = System(eqs, t)
572572
push!(
573573
cases,
574574
Dict(
@@ -586,7 +586,7 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
586586

587587
eqs = [x(k) ~ θ * x(k - 1)^3]
588588

589-
@mtkbuild eqs = System(eqs, t)
589+
@named eqs = System(eqs, t)
590590
push!(
591591
cases,
592592
Dict(
@@ -604,7 +604,7 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
604604

605605
eqs = [x1(k) ~ x1(k - 1) + x2(k - 1), x2(k) ~ x2(k - 1) + θ + β]
606606

607-
@mtkbuild eqs = System(eqs, t)
607+
@named eqs = System(eqs, t)
608608
push!(
609609
cases,
610610
Dict(
@@ -625,7 +625,7 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
625625
x2(k) ~ -c * x2(k - 1) + d * x1(k - 1) * x2(k - 1),
626626
]
627627

628-
@mtkbuild lv = System(eqs, t)
628+
@named lv = System(eqs, t)
629629
push!(
630630
cases,
631631
Dict(
@@ -767,7 +767,7 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
767767

768768
eqs = [x1(k) ~ x1(k - 1) + a]
769769

770-
@mtkbuild kic = System(eqs, t)
770+
@named kic = System(eqs, t)
771771
push!(
772772
cases,
773773
Dict(
@@ -958,7 +958,7 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
958958
beta * delta,
959959
alpha * gama,
960960
beta + delta,
961-
-delta * x3_0 + gama * x2_0,
961+
beta * x3_0 + gama * x2_0,
962962
],
963963
:correct_ident => OrderedDict(alpha => :locally, alpha * gama => :globally),
964964
),

0 commit comments

Comments
 (0)