Skip to content

Commit 9260101

Browse files
committed
Method to get link targets.
1 parent b7eed53 commit 9260101

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Added
2727
(`#698 <https://github.com/omni-us/jsonargparse/pull/698>`__).
2828
- Option to enable validation of default values (`#711
2929
<https://github.com/omni-us/jsonargparse/pull/711>`__).
30+
- New method to get a list of link targets (`#715
31+
<https://github.com/omni-us/jsonargparse/pull/715>`__).
3032

3133
Changed
3234
^^^^^^^

jsonargparse/_link_arguments.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,7 @@ def link_arguments(
515515
ValueError: If an invalid parameter is given.
516516
"""
517517
ActionLink(self, source, target, compute_fn, apply_on)
518+
519+
def get_link_targets(self, apply_on: str) -> List[str]:
520+
"""Get all keys that are targets of links."""
521+
return [a.target[0] for a in get_link_actions(self, apply_on)] # type: ignore[arg-type]

jsonargparse_tests/test_link_arguments.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def test_on_parse_shallow_print_config(parser):
3535
parser.link_arguments("a", "b")
3636
out = get_parse_args_stdout(parser, ["--print_config"])
3737
assert json_or_yaml_load(out) == {"a": 0}
38+
assert parser.get_link_targets("parse") == ["b"]
3839

3940

4041
def test_on_parse_subcommand_failing_compute_fn(parser, subparser, subtests):
@@ -107,6 +108,9 @@ def test_on_parse_compute_fn_subclass_spec(parser, subtests):
107108
assert cfg.cal1.init_args.firstweekday == 2
108109
assert cfg.cal2.init_args.firstweekday == 3
109110

111+
with subtests.test("get_link_targets"):
112+
assert parser.get_link_targets("parse") == ["cal2.init_args.firstweekday"]
113+
110114
with subtests.test("invalid init parameter"):
111115
parser.set_defaults(cal1=None)
112116
with pytest.raises(ArgumentError) as ctx:
@@ -164,6 +168,9 @@ def test_on_parse_add_class_arguments(subtests):
164168
dump = json_or_yaml_load(parser.dump(cfg, skip_link_targets=False))
165169
assert dump == {"a": {"v1": 11, "v2": 7}, "b": {"v3": 2, "v1": 7, "v2": 18}}
166170

171+
with subtests.test("get_link_targets"):
172+
assert parser.get_link_targets("parse") == ["b.v1", "b.v2"]
173+
167174
with subtests.test("argument error"):
168175
pytest.raises(ArgumentError, lambda: parser.parse_args(["--b.v1=5"]))
169176

@@ -209,6 +216,9 @@ def add(v1, v2):
209216
dump = json_or_yaml_load(parser.dump(cfg, skip_link_targets=False))
210217
assert dump["s2"] == {"class_path": f"{__name__}.ClassS2", "init_args": {"v3": 4}}
211218

219+
with subtests.test("get_link_targets"):
220+
assert parser.get_link_targets("parse") == ["s2.init_args.v3"]
221+
212222
with subtests.test("compute_fn invalid result type"):
213223
s1_value["init_args"] = {"v1": "a", "v2": "b"}
214224
with pytest.raises(ArgumentError):
@@ -237,6 +247,7 @@ def test_on_parse_subclass_target_in_union(parser):
237247
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=Logger"])
238248
assert cfg.trainer.save_dir == "logs"
239249
assert cfg.trainer.logger.init_args == Namespace(save_dir="logs")
250+
assert parser.get_link_targets("parse") == ["trainer.logger.init_args.save_dir"]
240251

241252

242253
class TrainerLoggerList:
@@ -530,6 +541,7 @@ def test_on_instantiate_link_instance_attribute():
530541
init = parser.instantiate_classes(cfg)
531542
assert init.x.x1 == 6
532543
assert init.y.y3 == '"8"'
544+
assert parser.get_link_targets("instantiate") == ["x.x1", "y.y1", "y.y3"]
533545

534546

535547
def test_on_instantiate_link_all_group_arguments():
@@ -542,6 +554,7 @@ def test_on_instantiate_link_all_group_arguments():
542554
assert init["x"].x2 == 7
543555
help_str = get_parser_help(parser)
544556
assert "Group 'x': All arguments are derived from links" in help_str
557+
assert parser.get_link_targets("instantiate") == ["x.x1", "y.y1", "x.x2"]
545558

546559

547560
class FailingComputeFn1:
@@ -600,6 +613,7 @@ def test_on_parse_and_instantiate_link_entire_instance(parser):
600613
assert isinstance(init.n, Namespace)
601614
assert isinstance(init.c, Calendar)
602615
assert init.c is init.n.calendar
616+
assert parser.get_link_targets("instantiate") == ["n.calendar"]
603617

604618

605619
class ClassM:
@@ -645,6 +659,7 @@ def test_on_instantiate_link_object_in_attribute(parser):
645659
init = parser.instantiate_classes(cfg)
646660
assert init.p.calendar is init.q.calendar
647661
assert init.q.calendar.firstweekday == 2
662+
assert parser.get_link_targets("instantiate") == ["q.calendar"]
648663

649664

650665
def test_on_parse_link_entire_subclass(parser):
@@ -656,6 +671,7 @@ def test_on_parse_link_entire_subclass(parser):
656671
cfg = parser.parse_args([f"--n.calendar={json.dumps(cal)}", "--q.q2=7"])
657672
assert cfg.n.calendar == cfg.q.calendar
658673
assert cfg.q.q2 == 7
674+
assert parser.get_link_targets("parse") == ["q.calendar"]
659675

660676

661677
class ClassV:
@@ -789,6 +805,7 @@ def test_on_instantiate_within_deep_subclass(parser, caplog):
789805
assert isinstance(init.model.decoder, WithinDeepTarget)
790806
assert init.model.decoder.input_channels == 16
791807
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text
808+
assert parser.get_link_targets("instantiate") == ["model.init_args.decoder.init_args.input_channels"]
792809

793810

794811
class WithinDeeperSystem:
@@ -824,6 +841,9 @@ def test_on_instantiate_within_deeper_subclass(parser, caplog):
824841
assert isinstance(init.system.model.decoder, WithinDeepTarget)
825842
assert init.system.model.decoder.input_channels == 16
826843
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text
844+
assert parser.get_link_targets("instantiate") == [
845+
"system.init_args.model.init_args.decoder.init_args.input_channels"
846+
]
827847

828848

829849
class SourceA:

0 commit comments

Comments
 (0)