@@ -185,7 +185,9 @@ def test_rst_role(self, doctype, expected):
185
185
)
186
186
@pytest .mark .parametrize ("name" , ["array" , "ndarray" , "array-like" , "array_like" ])
187
187
@pytest .mark .parametrize ("dtype" , ["int" , "np.int8" ])
188
- @pytest .mark .parametrize ("shape" , ["(2, 3)" , "(N, m)" , "3D" , "2-D" , "(N, ...)" ])
188
+ @pytest .mark .parametrize ("shape" ,
189
+ ["(2, 3)" , "(N, m)" , "3D" , "2-D" , "(N, ...)" , "([P,] M, N)" ]
190
+ )
189
191
def test_natlang_array (self , fmt , expected_fmt , name , dtype , shape ):
190
192
191
193
def escape (name : str ) -> str :
@@ -202,6 +204,18 @@ def escape(name: str) -> str:
202
204
assert annotation .value == expected
203
205
# fmt: on
204
206
207
+ @pytest .mark .parametrize (
208
+ ("doctype" , "expected" ),
209
+ [
210
+ ("ndarray of dtype (int or float)" , "ndarray[int | float]" ),
211
+ ("([P,] M, N) (int or float) array" , "array[int | float]" ),
212
+ ],
213
+ )
214
+ def test_natlang_array_specific (self , doctype , expected ):
215
+ transformer = DoctypeTransformer ()
216
+ annotation , _ = transformer .doctype_to_annotation (doctype )
217
+ assert annotation .value == expected
218
+
205
219
@pytest .mark .parametrize ("shape" , ["(-1, 3)" , "(1.0, 2)" , "-3D" , "-2-D" ])
206
220
def test_natlang_array_invalid_shape (self , shape ):
207
221
doctype = f"array of shape { shape } "
0 commit comments