2
2
3
3
import ctypes
4
4
import dataclasses
5
+ import functools
5
6
import inspect
7
+ import re
6
8
import typing
7
9
from collections .abc import Callable
8
10
from io import StringIO
16
18
TypeMethod ,
17
19
type_add_method ,
18
20
type_index2type_methods ,
21
+ type_table ,
19
22
)
20
23
from mlc .core import typing as mlc_typing
21
24
@@ -325,13 +328,9 @@ def add_vtable_methods_for_type_cls(type_cls: type, type_index: int) -> None:
325
328
type_add_method (type_index , name , func , kind = 0 )
326
329
327
330
328
- def prototype_py (type_info : type | TypeInfo ) -> str :
329
- if not isinstance (type_info , TypeInfo ):
330
- if (type_info := getattr (type_info , "_mlc_type_info" , None )) is None : # type: ignore[assignment]
331
- raise ValueError (f"Invalid type: { type_info } " )
331
+ def _prototype_py (type_info : TypeInfo ) -> str :
332
332
assert isinstance (type_info , TypeInfo )
333
333
cls_name = type_info .type_key .rsplit ("." , maxsplit = 1 )[- 1 ]
334
-
335
334
io = StringIO ()
336
335
print (f"@mlc.dataclasses.c_class({ type_info .type_key !r} )" , file = io )
337
336
print (f"class { cls_name } :" , file = io )
@@ -352,12 +351,11 @@ def prototype_py(type_info: type | TypeInfo) -> str:
352
351
return io .getvalue ().rstrip ()
353
352
354
353
355
- def prototype_cxx ( type_info : type | TypeInfo ) -> str :
356
- if not isinstance ( type_info , TypeInfo ):
357
- if ( type_info := getattr ( type_info , "_mlc_type_info" , None )) is None : # type: ignore[assignment]
358
- raise ValueError ( f"Invalid type: { type_info } " )
354
+ def _prototype_cxx (
355
+ type_info : TypeInfo ,
356
+ export_macro : str = "_EXPORTS" ,
357
+ ) -> str :
359
358
assert isinstance (type_info , TypeInfo )
360
-
361
359
parent_type_info = type_info .get_parent ()
362
360
namespaces = type_info .type_key .split ("." )
363
361
cls_name = namespaces [- 1 ]
@@ -388,22 +386,22 @@ def prototype_cxx(type_info: type | TypeInfo) -> str:
388
386
if i != 0 :
389
387
print (", " , file = io , end = "" )
390
388
print (f"{ ty } { name } " , file = io , end = "" )
391
- print ("): " , file = io , end = "" )
392
- for i , (name , _ ) in enumerate (fields ):
393
- if i != 0 :
394
- print (", " , file = io , end = "" )
395
- print (f"{ name } ({ name } )" , file = io , end = "" )
389
+ print ("): _mlc_header{}" , file = io , end = "" )
390
+ for name , _ in fields :
391
+ print (f", { name } ({ name } )" , file = io , end = "" )
396
392
print (" {}" , file = io )
397
393
# Step 2.3. Macro to define object type
398
394
print (
399
- f' MLC_DEF_DYN_TYPE(_EXPORTS , { cls_name } Obj, { parent_obj_name } , "{ type_info .type_key } ");' ,
395
+ f' MLC_DEF_DYN_TYPE({ export_macro } , { cls_name } Obj, { parent_obj_name } , "{ type_info .type_key } ");' ,
400
396
file = io ,
401
397
)
402
398
print (f"}}; // struct { cls_name } Obj\n " , file = io )
403
399
# Step 3. Object reference class
404
400
print (f"struct { cls_name } : public { parent_ref_name } {{" , file = io )
405
401
# Step 3.1. Define fields for reflection
406
- print (f" MLC_DEF_OBJ_REF(_EXPORTS, { cls_name } , { cls_name } Obj, { parent_ref_name } )" , file = io )
402
+ print (
403
+ f" MLC_DEF_OBJ_REF({ export_macro } , { cls_name } , { cls_name } Obj, { parent_ref_name } )" , file = io
404
+ )
407
405
for name , _ in fields :
408
406
print (f' .Field("{ name } ", &{ cls_name } Obj::{ name } )' , file = io )
409
407
# Step 3.2. Define `__init__` method for reflection
@@ -416,3 +414,34 @@ def prototype_cxx(type_info: type | TypeInfo) -> str:
416
414
for ns in reversed (namespaces [:- 1 ]):
417
415
print (f"}} // namespace { ns } " , file = io )
418
416
return io .getvalue ().rstrip ()
417
+
418
+
419
+ def prototype (
420
+ match : str | type | TypeInfo | Callable [[TypeInfo ], bool ],
421
+ lang : Literal ["c++" , "py" ] = "c++" ,
422
+ export_macro : str = "_EXPORTS" ,
423
+ ) -> str :
424
+ type_info_list : list [TypeInfo ]
425
+ if (
426
+ isinstance (match , type )
427
+ and (type_info := getattr (match , "_mlc_type_info" , None )) is not None
428
+ ):
429
+ assert isinstance (type_info , TypeInfo )
430
+ type_info_list = [type_info ]
431
+ elif isinstance (match , TypeInfo ):
432
+ type_info_list = [match ]
433
+ elif isinstance (match , str ):
434
+ pattern = re .compile (match )
435
+ type_info_list = [i for i in type_table () if i and pattern .fullmatch (i .type_key )]
436
+ elif callable (match ):
437
+ type_info_list = [i for i in type_table () if i and match (i )]
438
+ else :
439
+ raise ValueError (f"Invalid `match`: { match } " )
440
+ fn : Callable [[TypeInfo ], str ]
441
+ if lang == "c++" :
442
+ fn = functools .partial (_prototype_cxx , export_macro = export_macro )
443
+ elif lang == "py" :
444
+ fn = _prototype_py
445
+ else :
446
+ raise ValueError (f"Invalid `lang`: { lang } " )
447
+ return "\n \n " .join (fn (i ) for i in type_info_list )
0 commit comments