Skip to content

Commit 4787752

Browse files
authored
Fix: Optional pydantic model failing to parse with __pydantic_private__ error (#530).
1 parent 89577a3 commit 4787752

File tree

5 files changed

+42
-10
lines changed

5 files changed

+42
-10
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ Fixed
4040
- List of union of classes not accepted by ``add_subclass_arguments`` in
4141
``python>=3.11`` (`#522
4242
<https://github.com/omni-us/jsonargparse/pull/522>`__).
43+
- Optional pydantic model failing to parse with `__pydantic_private__` error
44+
(`#521 <https://github.com/omni-us/jsonargparse/issues/521>`__).
4345

4446

4547
v4.29.0 (2024-05-24)

jsonargparse/_signatures.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
is_dataclass_like,
1717
is_subclass,
1818
)
19+
from ._namespace import Namespace
1920
from ._optionals import get_doc_short_description, is_pydantic_model, pydantic_support
2021
from ._parameter_resolvers import (
2122
ParamData,
@@ -29,7 +30,7 @@
2930
get_subclasses_from_type,
3031
is_optional,
3132
)
32-
from ._util import get_private_kwargs, iter_to_set_str
33+
from ._util import NoneType, get_private_kwargs, iter_to_set_str
3334
from .typing import register_pydantic_type
3435

3536
__all__ = [
@@ -51,7 +52,7 @@ def add_class_arguments(
5152
nested_key: Optional[str] = None,
5253
as_group: bool = True,
5354
as_positional: bool = False,
54-
default: Optional[LazyInitBaseClass] = None,
55+
default: Optional[Union[dict, Namespace, LazyInitBaseClass]] = None,
5556
skip: Optional[Set[Union[str, int]]] = None,
5657
instantiate: bool = True,
5758
fail_untyped: bool = True,
@@ -67,7 +68,7 @@ def add_class_arguments(
6768
nested_key: Key for nested namespace.
6869
as_group: Whether arguments should be added to a new argument group.
6970
as_positional: Whether to add required parameters as positional arguments.
70-
default: Default value used to override parameter defaults. Must be lazy_instance.
71+
default: Default value used to override parameter defaults.
7172
skip: Names of parameters or number of positionals that should be skipped.
7273
instantiate: Whether the class group should be instantiated by :code:`instantiate_classes`.
7374
fail_untyped: Whether to raise exception if a required parameter does not have a type.
@@ -81,9 +82,14 @@ def add_class_arguments(
8182
ValueError: When there are required parameters without at least one valid type.
8283
"""
8384
if not inspect.isclass(get_generic_origin(get_unaliased_type(theclass))):
84-
raise ValueError(f'Expected "theclass" parameter to be a class type, got: {theclass}.')
85-
if default and not (isinstance(default, LazyInitBaseClass) and isinstance(default, theclass)):
86-
raise ValueError(f'Expected "default" parameter to be a lazy instance of the class, got: {default}.')
85+
raise ValueError(f"Expected 'theclass' parameter to be a class type, got: {theclass}")
86+
if not (
87+
isinstance(default, (NoneType, dict, Namespace))
88+
or (isinstance(default, LazyInitBaseClass) and isinstance(default, theclass))
89+
):
90+
raise ValueError(
91+
f"Expected 'default' parameter to be a dict, Namespace or lazy instance of the class, got: {default}"
92+
)
8793
linked_targets = get_private_kwargs(kwargs, linked_targets=None)
8894

8995
added_args = self._add_signature_arguments(
@@ -102,9 +108,13 @@ def add_class_arguments(
102108
if default:
103109
skip = skip or set()
104110
prefix = nested_key + "." if nested_key else ""
105-
defaults = default.lazy_get_init_args()
111+
defaults = default
112+
if isinstance(default, LazyInitBaseClass):
113+
defaults = default.lazy_get_init_args().as_dict()
114+
elif isinstance(default, Namespace):
115+
defaults = default.as_dict()
106116
if defaults:
107-
defaults = {prefix + k: v for k, v in defaults.__dict__.items() if k not in skip}
117+
defaults = {prefix + k: v for k, v in defaults.items() if k not in skip}
108118
self.set_defaults(**defaults) # type: ignore[attr-defined]
109119

110120
return added_args

jsonargparse/_typehints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ def adapt_typehints(
960960
elif is_dataclass_like(typehint):
961961
if isinstance(prev_val, (dict, Namespace)):
962962
assert isinstance(sub_add_kwargs, dict)
963-
sub_add_kwargs["default"] = lazy_instance(typehint, **prev_val)
963+
sub_add_kwargs["default"] = prev_val
964964
parser = ActionTypeHint.get_class_parser(typehint, sub_add_kwargs=sub_add_kwargs)
965965
if instantiate_classes:
966966
init_args = parser.instantiate_classes(val)

jsonargparse_tests/test_dataclass_like.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,10 @@ class PydanticHelp(pydantic.BaseModel):
680680
class PydanticAnnotatedField(pydantic.BaseModel):
681681
p1: annotated[int, pydantic.Field(default=2, ge=1, le=8)] # type: ignore[valid-type]
682682

683+
class OptionalPydantic:
684+
def __init__(self, a: Optional[PydanticModel] = None):
685+
self.a = a
686+
683687

684688
def none(x):
685689
return x
@@ -797,6 +801,22 @@ def test_dataclass_nested(self, parser):
797801
cfg = parser.parse_args(["--data", '{"p3": {"p1": 1.0}}'])
798802
assert cfg.data == Namespace(p3=Namespace(p1=1.0, p2="-"))
799803

804+
def test_optional_pydantic_model(self, parser):
805+
parser.add_argument("--b", type=OptionalPydantic)
806+
parser.add_argument("--cfg", action="config")
807+
cfg = parser.parse_args([f"--b={__name__}.OptionalPydantic"])
808+
assert cfg.b.class_path == f"{__name__}.OptionalPydantic"
809+
assert cfg.b.init_args == Namespace(a=None)
810+
config = {
811+
"b": {
812+
"class_path": f"{__name__}.OptionalPydantic",
813+
"init_args": {"a": {"p1": "x"}},
814+
}
815+
}
816+
cfg = parser.parse_args([f"--cfg={config}"])
817+
assert cfg.b.class_path == f"{__name__}.OptionalPydantic"
818+
assert cfg.b.init_args == Namespace(a=Namespace(p1="x", p2=3))
819+
800820

801821
# attrs tests
802822

jsonargparse_tests/test_signatures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
def test_add_class_failure_not_a_class(parser):
8181
with pytest.raises(ValueError) as ctx:
8282
parser.add_class_arguments("Not a class")
83-
ctx.match('Expected "theclass" parameter to be a class')
83+
ctx.match("Expected 'theclass' parameter to be a class")
8484

8585

8686
def test_add_class_failure_positional_without_type(parser):

0 commit comments

Comments
 (0)