Skip to content

Commit ba82052

Browse files
committed
Experimental support for custom instantiators receiving values applied by instantiation links (Lightning-AI/pytorch-lightning#20311).
1 parent b7eed53 commit ba82052

File tree

5 files changed

+64
-1
lines changed

5 files changed

+64
-1
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ Added
2727
(`#698 <https://github.com/omni-us/jsonargparse/pull/698>`__).
2828
- Option to enable validation of default values (`#711
2929
<https://github.com/omni-us/jsonargparse/pull/711>`__).
30+
- Experimental support for custom instantiators receiving values applied by
31+
instantiation links (`#716
32+
<https://github.com/omni-us/jsonargparse/pull/716>`__).
3033

3134
Changed
3235
^^^^^^^

jsonargparse/_common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
6464
load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None)
6565
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators", default=None)
6666
nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[])
67+
applied_instantiation_links: ContextVar[Optional[set]] = ContextVar("applied_instantiation_links", default=None)
6768

6869

6970
parser_context_vars = dict(
@@ -74,6 +75,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
7475
load_value_mode=load_value_mode,
7576
class_instantiators=class_instantiators,
7677
nested_links=nested_links,
78+
applied_instantiation_links=applied_instantiation_links,
7779
)
7880

7981

@@ -270,6 +272,12 @@ def __init__(self, instantiators: InstantiatorsDictType) -> None:
270272
def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
271273
for (cls, subclasses), instantiator in self.instantiators.items():
272274
if class_type is cls or (subclasses and is_subclass(class_type, cls)):
275+
param_names = set(inspect.signature(instantiator).parameters.keys())
276+
if "applied_instantiation_links" in param_names:
277+
applied_links = applied_instantiation_links.get() or set()
278+
kwargs["applied_instantiation_links"] = {
279+
action.target[0]: action.applied_value for action in applied_links
280+
}
273281
return instantiator(class_type, *args, **kwargs)
274282
return default_class_instantiator(class_type, *args, **kwargs)
275283

jsonargparse/_core.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,14 @@ def add_instantiator(
11741174
For reference, the default instantiator is ``return class_type(*args,
11751175
**kwargs)``.
11761176
1177+
In some use cases, the instantiator function might need access to values
1178+
applied by instantiation links. For this, the instantiator function can
1179+
have an additional keyword parameter ``applied_instantiation_links:
1180+
dict``. This parameter will be populated with a dictionary having as
1181+
keys the targets of the instantiation links and corresponding values
1182+
that were applied. Support for ``applied_instantiation_links`` parameter
1183+
is EXPERIMENTAL and subject to change or removal in future versions.
1184+
11771185
Args:
11781186
instantiator: Function that instantiates a class.
11791187
class_type: The class type to instantiate.
@@ -1246,10 +1254,15 @@ def instantiate_classes(
12461254
parent_parser=self,
12471255
nested_links=ActionLink.get_nested_links(self, component),
12481256
class_instantiators=self._get_instantiators(),
1257+
applied_instantiation_links=cfg.get("__applied_instantiation_links__"),
12491258
):
12501259
parent[key] = component.instantiate_classes(value)
12511260
else:
1252-
with parser_context(load_value_mode=self.parser_mode, class_instantiators=self._get_instantiators()):
1261+
with parser_context(
1262+
load_value_mode=self.parser_mode,
1263+
class_instantiators=self._get_instantiators(),
1264+
applied_instantiation_links=cfg.get("__applied_instantiation_links__"),
1265+
):
12531266
component.instantiate_class(component, cfg)
12541267

12551268
ActionLink.apply_instantiation_links(self, cfg, order=order)

jsonargparse/_link_arguments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def apply_instantiation_links(parser, cfg, target=None, order=None):
363363
else:
364364
value = action.call_compute_fn(source_objects)
365365
ActionLink.set_target_value(action, value, cfg, parser.logger)
366+
action.applied_value = value
366367
applied_links.add(action)
367368
parser.logger.debug(f"Applied link '{action.option_strings[0]}'.")
368369

jsonargparse_tests/test_link_arguments.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,44 @@ def test_on_instantiate_targets_share_parent(parser):
900900
assert init.root.child.attr_child is init.source_b.attr_b
901901

902902

903+
class Dataloader:
904+
def __init__(self, batch_size: int = 6):
905+
self.batch_size = batch_size
906+
self.num_classes = 7
907+
908+
909+
class CustomOptimizer(Optimizer):
910+
def __init__(self, params: List[int], num_classes: int, **kwargs):
911+
super().__init__(params, **kwargs)
912+
913+
914+
def custom_instantiator(class_type, *args, applied_instantiation_links: dict, **kwargs):
915+
init = class_type(*args, **kwargs)
916+
init.applied_instantiation_links = applied_instantiation_links
917+
return init
918+
919+
920+
def test_on_instantiate_targets_passed_to_instantiator(parser):
921+
parser.add_argument("--data", type=Dataloader)
922+
parser.add_argument("--model", type=Model)
923+
parser.link_arguments(
924+
"data.num_classes",
925+
"model.init_args.optimizer.init_args.num_classes",
926+
apply_on="instantiate",
927+
)
928+
parser.add_instantiator(custom_instantiator, Dataloader, subclasses=True)
929+
parser.add_instantiator(custom_instantiator, Model, subclasses=True)
930+
931+
cfg = parser.parse_args(["--data=Dataloader", "--model=Model", "--model.label=ok"])
932+
init = parser.instantiate_classes(cfg)
933+
934+
assert isinstance(init.data, Dataloader)
935+
assert init.data.applied_instantiation_links == {}
936+
assert isinstance(init.model, Model)
937+
assert callable(init.model.optimizer)
938+
assert init.model.applied_instantiation_links == {"model.init_args.optimizer.init_args.num_classes": 7}
939+
940+
903941
# link creation failures
904942

905943

0 commit comments

Comments
 (0)