Skip to content

Commit 8c26223

Browse files
authored
Fix not able to modify init args for callable with class return and default class. (#504)
1 parent 044652f commit 8c26223

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ Fixed
2828
space (`#499 <https://github.com/omni-us/jsonargparse/pull/499>`__).
2929
- ``format_usage()`` not working (`#501
3030
<https://github.com/omni-us/jsonargparse/issues/501>`__).
31+
- Not able to modify init args for callable with class return and default class
32+
(`#5?? <https://github.com/omni-us/jsonargparse/pull/5??>`__).
3133

3234

3335
v4.28.0 (2024-04-17)

jsonargparse/_typehints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,8 @@ def subclass_spec_as_namespace(val, prev_val=None):
980980
val = Namespace({root_key: val})
981981
if isinstance(prev_val, str):
982982
prev_val = Namespace(class_path=prev_val)
983+
elif inspect.isclass(prev_val):
984+
prev_val = Namespace(class_path=get_import_path(prev_val))
983985
if isinstance(val, dict):
984986
val = Namespace(val)
985987
if "init_args" in val and isinstance(val["init_args"], dict):

jsonargparse_tests/test_typehints.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,13 @@ def test_callable_multiple_args_return_type_class(parser, subtests):
792792
assert f"{__name__}.{name}" in help_str
793793

794794

795+
def test_callable_return_class_default_class_override_init_arg(parser):
796+
parser.add_argument("--optimizer", type=Callable[[List[float]], Optimizer], default=SGD)
797+
cfg = parser.parse_args(["--optimizer.momentum=0.5", "--optimizer.lr=0.05"])
798+
assert cfg.optimizer.class_path == f"{__name__}.SGD"
799+
assert cfg.optimizer.init_args == Namespace(lr=0.05, momentum=0.5)
800+
801+
795802
class StepLR:
796803
def __init__(self, optimizer: Optimizer, last_epoch: int = -1):
797804
self.optimizer = optimizer

0 commit comments

Comments
 (0)