@@ -35,6 +35,7 @@ def test_on_parse_shallow_print_config(parser):
35
35
parser .link_arguments ("a" , "b" )
36
36
out = get_parse_args_stdout (parser , ["--print_config" ])
37
37
assert json_or_yaml_load (out ) == {"a" : 0 }
38
+ assert parser .get_link_targets ("parse" ) == ["b" ]
38
39
39
40
40
41
def test_on_parse_subcommand_failing_compute_fn (parser , subparser , subtests ):
@@ -107,6 +108,9 @@ def test_on_parse_compute_fn_subclass_spec(parser, subtests):
107
108
assert cfg .cal1 .init_args .firstweekday == 2
108
109
assert cfg .cal2 .init_args .firstweekday == 3
109
110
111
+ with subtests .test ("get_link_targets" ):
112
+ assert parser .get_link_targets ("parse" ) == ["cal2.init_args.firstweekday" ]
113
+
110
114
with subtests .test ("invalid init parameter" ):
111
115
parser .set_defaults (cal1 = None )
112
116
with pytest .raises (ArgumentError ) as ctx :
@@ -164,6 +168,9 @@ def test_on_parse_add_class_arguments(subtests):
164
168
dump = json_or_yaml_load (parser .dump (cfg , skip_link_targets = False ))
165
169
assert dump == {"a" : {"v1" : 11 , "v2" : 7 }, "b" : {"v3" : 2 , "v1" : 7 , "v2" : 18 }}
166
170
171
+ with subtests .test ("get_link_targets" ):
172
+ assert parser .get_link_targets ("parse" ) == ["b.v1" , "b.v2" ]
173
+
167
174
with subtests .test ("argument error" ):
168
175
pytest .raises (ArgumentError , lambda : parser .parse_args (["--b.v1=5" ]))
169
176
@@ -209,6 +216,9 @@ def add(v1, v2):
209
216
dump = json_or_yaml_load (parser .dump (cfg , skip_link_targets = False ))
210
217
assert dump ["s2" ] == {"class_path" : f"{ __name__ } .ClassS2" , "init_args" : {"v3" : 4 }}
211
218
219
+ with subtests .test ("get_link_targets" ):
220
+ assert parser .get_link_targets ("parse" ) == ["s2.init_args.v3" ]
221
+
212
222
with subtests .test ("compute_fn invalid result type" ):
213
223
s1_value ["init_args" ] = {"v1" : "a" , "v2" : "b" }
214
224
with pytest .raises (ArgumentError ):
@@ -237,6 +247,7 @@ def test_on_parse_subclass_target_in_union(parser):
237
247
cfg = parser .parse_args (["--trainer.save_dir=logs" , "--trainer.logger=Logger" ])
238
248
assert cfg .trainer .save_dir == "logs"
239
249
assert cfg .trainer .logger .init_args == Namespace (save_dir = "logs" )
250
+ assert parser .get_link_targets ("parse" ) == ["trainer.logger.init_args.save_dir" ]
240
251
241
252
242
253
class TrainerLoggerList :
@@ -530,6 +541,7 @@ def test_on_instantiate_link_instance_attribute():
530
541
init = parser .instantiate_classes (cfg )
531
542
assert init .x .x1 == 6
532
543
assert init .y .y3 == '"8"'
544
+ assert parser .get_link_targets ("instantiate" ) == ["x.x1" , "y.y1" , "y.y3" ]
533
545
534
546
535
547
def test_on_instantiate_link_all_group_arguments ():
@@ -542,6 +554,7 @@ def test_on_instantiate_link_all_group_arguments():
542
554
assert init ["x" ].x2 == 7
543
555
help_str = get_parser_help (parser )
544
556
assert "Group 'x': All arguments are derived from links" in help_str
557
+ assert parser .get_link_targets ("instantiate" ) == ["x.x1" , "y.y1" , "x.x2" ]
545
558
546
559
547
560
class FailingComputeFn1 :
@@ -600,6 +613,7 @@ def test_on_parse_and_instantiate_link_entire_instance(parser):
600
613
assert isinstance (init .n , Namespace )
601
614
assert isinstance (init .c , Calendar )
602
615
assert init .c is init .n .calendar
616
+ assert parser .get_link_targets ("instantiate" ) == ["n.calendar" ]
603
617
604
618
605
619
class ClassM :
@@ -645,6 +659,7 @@ def test_on_instantiate_link_object_in_attribute(parser):
645
659
init = parser .instantiate_classes (cfg )
646
660
assert init .p .calendar is init .q .calendar
647
661
assert init .q .calendar .firstweekday == 2
662
+ assert parser .get_link_targets ("instantiate" ) == ["q.calendar" ]
648
663
649
664
650
665
def test_on_parse_link_entire_subclass (parser ):
@@ -656,6 +671,7 @@ def test_on_parse_link_entire_subclass(parser):
656
671
cfg = parser .parse_args ([f"--n.calendar={ json .dumps (cal )} " , "--q.q2=7" ])
657
672
assert cfg .n .calendar == cfg .q .calendar
658
673
assert cfg .q .q2 == 7
674
+ assert parser .get_link_targets ("parse" ) == ["q.calendar" ]
659
675
660
676
661
677
class ClassV :
@@ -789,6 +805,7 @@ def test_on_instantiate_within_deep_subclass(parser, caplog):
789
805
assert isinstance (init .model .decoder , WithinDeepTarget )
790
806
assert init .model .decoder .input_channels == 16
791
807
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog .text
808
+ assert parser .get_link_targets ("instantiate" ) == ["model.init_args.decoder.init_args.input_channels" ]
792
809
793
810
794
811
class WithinDeeperSystem :
@@ -824,6 +841,9 @@ def test_on_instantiate_within_deeper_subclass(parser, caplog):
824
841
assert isinstance (init .system .model .decoder , WithinDeepTarget )
825
842
assert init .system .model .decoder .input_channels == 16
826
843
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog .text
844
+ assert parser .get_link_targets ("instantiate" ) == [
845
+ "system.init_args.model.init_args.decoder.init_args.input_channels"
846
+ ]
827
847
828
848
829
849
class SourceA :
0 commit comments