16
16
is_dataclass_like ,
17
17
is_subclass ,
18
18
)
19
+ from ._namespace import Namespace
19
20
from ._optionals import get_doc_short_description , is_pydantic_model , pydantic_support
20
21
from ._parameter_resolvers import (
21
22
ParamData ,
29
30
get_subclasses_from_type ,
30
31
is_optional ,
31
32
)
32
- from ._util import get_private_kwargs , iter_to_set_str
33
+ from ._util import NoneType , get_private_kwargs , iter_to_set_str
33
34
from .typing import register_pydantic_type
34
35
35
36
__all__ = [
@@ -51,7 +52,7 @@ def add_class_arguments(
51
52
nested_key : Optional [str ] = None ,
52
53
as_group : bool = True ,
53
54
as_positional : bool = False ,
54
- default : Optional [LazyInitBaseClass ] = None ,
55
+ default : Optional [Union [ dict , Namespace , LazyInitBaseClass ] ] = None ,
55
56
skip : Optional [Set [Union [str , int ]]] = None ,
56
57
instantiate : bool = True ,
57
58
fail_untyped : bool = True ,
@@ -67,7 +68,7 @@ def add_class_arguments(
67
68
nested_key: Key for nested namespace.
68
69
as_group: Whether arguments should be added to a new argument group.
69
70
as_positional: Whether to add required parameters as positional arguments.
70
- default: Default value used to override parameter defaults. Must be lazy_instance.
71
+ default: Default value used to override parameter defaults.
71
72
skip: Names of parameters or number of positionals that should be skipped.
72
73
instantiate: Whether the class group should be instantiated by :code:`instantiate_classes`.
73
74
fail_untyped: Whether to raise exception if a required parameter does not have a type.
@@ -81,9 +82,14 @@ def add_class_arguments(
81
82
ValueError: When there are required parameters without at least one valid type.
82
83
"""
83
84
if not inspect .isclass (get_generic_origin (get_unaliased_type (theclass ))):
84
- raise ValueError (f'Expected "theclass" parameter to be a class type, got: { theclass } .' )
85
- if default and not (isinstance (default , LazyInitBaseClass ) and isinstance (default , theclass )):
86
- raise ValueError (f'Expected "default" parameter to be a lazy instance of the class, got: { default } .' )
85
+ raise ValueError (f"Expected 'theclass' parameter to be a class type, got: { theclass } " )
86
+ if not (
87
+ isinstance (default , (NoneType , dict , Namespace ))
88
+ or (isinstance (default , LazyInitBaseClass ) and isinstance (default , theclass ))
89
+ ):
90
+ raise ValueError (
91
+ f"Expected 'default' parameter to be a dict, Namespace or lazy instance of the class, got: { default } "
92
+ )
87
93
linked_targets = get_private_kwargs (kwargs , linked_targets = None )
88
94
89
95
added_args = self ._add_signature_arguments (
@@ -102,9 +108,13 @@ def add_class_arguments(
102
108
if default :
103
109
skip = skip or set ()
104
110
prefix = nested_key + "." if nested_key else ""
105
- defaults = default .lazy_get_init_args ()
111
+ defaults = default
112
+ if isinstance (default , LazyInitBaseClass ):
113
+ defaults = default .lazy_get_init_args ().as_dict ()
114
+ elif isinstance (default , Namespace ):
115
+ defaults = default .as_dict ()
106
116
if defaults :
107
- defaults = {prefix + k : v for k , v in defaults .__dict__ . items () if k not in skip }
117
+ defaults = {prefix + k : v for k , v in defaults .items () if k not in skip }
108
118
self .set_defaults (** defaults ) # type: ignore[attr-defined]
109
119
110
120
return added_args
0 commit comments