Skip to content

Commit 5282c0b

Browse files
authored
Fix AttributeError when using union of dtypes in array expression (#52)
1 parent 3d6a3e1 commit 5282c0b

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

src/docstub/_docstrings.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,23 @@
3232

3333

3434
def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token:
35-
"""Find token with a specific type name in tree."""
36-
tokens = [child for child in tree.children if child.type == name]
35+
"""Find token with a specific type name in tree.
36+
37+
Parameters
38+
----------
39+
tree : lark.Tree
40+
name : str
41+
Name of the token to find in the children of `tree`.
42+
43+
Returns
44+
-------
45+
token : lark.Token
46+
"""
47+
tokens = [
48+
child
49+
for child in tree.children
50+
if hasattr(child, "type") and child.type == name
51+
]
3752
if len(tokens) != 1:
3853
msg = f"expected exactly one Token of type {name}, found {len(tokens)}"
3954
raise ValueError(msg)

tests/test_docstrings.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ def test_rst_role(self, doctype, expected):
185185
)
186186
@pytest.mark.parametrize("name", ["array", "ndarray", "array-like", "array_like"])
187187
@pytest.mark.parametrize("dtype", ["int", "np.int8"])
188-
@pytest.mark.parametrize("shape", ["(2, 3)", "(N, m)", "3D", "2-D", "(N, ...)"])
188+
@pytest.mark.parametrize("shape",
189+
["(2, 3)", "(N, m)", "3D", "2-D", "(N, ...)", "([P,] M, N)"]
190+
)
189191
def test_natlang_array(self, fmt, expected_fmt, name, dtype, shape):
190192

191193
def escape(name: str) -> str:
@@ -202,6 +204,18 @@ def escape(name: str) -> str:
202204
assert annotation.value == expected
203205
# fmt: on
204206

207+
@pytest.mark.parametrize(
208+
("doctype", "expected"),
209+
[
210+
("ndarray of dtype (int or float)", "ndarray[int | float]"),
211+
("([P,] M, N) (int or float) array", "array[int | float]"),
212+
],
213+
)
214+
def test_natlang_array_specific(self, doctype, expected):
215+
transformer = DoctypeTransformer()
216+
annotation, _ = transformer.doctype_to_annotation(doctype)
217+
assert annotation.value == expected
218+
205219
@pytest.mark.parametrize("shape", ["(-1, 3)", "(1.0, 2)", "-3D", "-2-D"])
206220
def test_natlang_array_invalid_shape(self, shape):
207221
doctype = f"array of shape {shape}"

0 commit comments

Comments
 (0)