Skip to content

Commit a8017ce

Browse files
authored
feat(dataclass): Introduce mlcd.prototype (#19)
This PR introduces `mlcd.prototype`, which aims to ease migration between C++ and Python. The `prototype` function prints the definition of existing Python dataclasses registered through the MLC API in the requested language; for example, with `lang="c++"` it prints their C++ definitions. More specifically, the command below: ``` python -c "import mlc.dataclasses as mlcd; print(mlcd.prototype(\"mlc.sym.*\", lang=\"c++\", export_macro=\"MLC_SYM_EXPORTS\"))" ``` prints the C++ definitions of all types whose type key matches the pattern `mlc.sym.*`, using `MLC_SYM_EXPORTS` as the export macro.
1 parent b584462 commit a8017ce

File tree

6 files changed

+63
-23
lines changed

6 files changed

+63
-23
lines changed

python/mlc/_cython/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
type_key2py_type_info,
5959
type_register_fields,
6060
type_register_structure,
61+
type_table,
6162
)
6263

6364
LIB: _ctypes.CDLL = _core.LIB

python/mlc/_cython/core.pyx

+4
Original file line numberDiff line numberDiff line change
@@ -1656,6 +1656,10 @@ def type_create(int32_t parent_type_index, str type_key):
16561656
return type_info
16571657

16581658

1659+
cpdef list type_table():
    """Return a freshly built list of the entries in ``TYPE_INDEX_TO_INFO``.

    A new list object is created on every call, so callers can filter or
    mutate the result without affecting the module-level table itself.
    """
    return [*TYPE_INDEX_TO_INFO]
1661+
1662+
16591663
cdef const char* _DLPACK_CAPSULE_NAME = "dltensor"
16601664
cdef const char* _DLPACK_CAPSULE_NAME_USED = "used_dltensor"
16611665
cdef const char* _DLPACK_CAPSULE_NAME_VER = "dltensor_versioned"

python/mlc/dataclasses/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
from .c_class import c_class
22
from .py_class import PyClass, py_class
3-
from .utils import Structure, add_vtable_method, field, prototype_cxx, prototype_py, vtable_method
3+
from .utils import (
4+
Structure,
5+
add_vtable_method,
6+
field,
7+
prototype,
8+
vtable_method,
9+
)

python/mlc/dataclasses/c_class.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
get_parent_type,
1919
inspect_dataclass_fields,
2020
method_init,
21-
prototype_py,
21+
prototype,
2222
)
2323

2424
ClsType = typing.TypeVar("ClsType")
@@ -117,5 +117,5 @@ def _check_c_class(
117117
if warned:
118118
warnings.warn(
119119
f"One or multiple warnings in `{type_cls.__module__}.{type_cls.__qualname__}`. Its prototype is:\n"
120-
+ prototype_py(type_info)
120+
+ prototype(type_info, lang="py")
121121
)

python/mlc/dataclasses/utils.py

+46-17
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import ctypes
44
import dataclasses
5+
import functools
56
import inspect
7+
import re
68
import typing
79
from collections.abc import Callable
810
from io import StringIO
@@ -16,6 +18,7 @@
1618
TypeMethod,
1719
type_add_method,
1820
type_index2type_methods,
21+
type_table,
1922
)
2023
from mlc.core import typing as mlc_typing
2124

@@ -325,13 +328,9 @@ def add_vtable_methods_for_type_cls(type_cls: type, type_index: int) -> None:
325328
type_add_method(type_index, name, func, kind=0)
326329

327330

328-
def prototype_py(type_info: type | TypeInfo) -> str:
329-
if not isinstance(type_info, TypeInfo):
330-
if (type_info := getattr(type_info, "_mlc_type_info", None)) is None: # type: ignore[assignment]
331-
raise ValueError(f"Invalid type: {type_info}")
331+
def _prototype_py(type_info: TypeInfo) -> str:
332332
assert isinstance(type_info, TypeInfo)
333333
cls_name = type_info.type_key.rsplit(".", maxsplit=1)[-1]
334-
335334
io = StringIO()
336335
print(f"@mlc.dataclasses.c_class({type_info.type_key!r})", file=io)
337336
print(f"class {cls_name}:", file=io)
@@ -352,12 +351,11 @@ def prototype_py(type_info: type | TypeInfo) -> str:
352351
return io.getvalue().rstrip()
353352

354353

355-
def prototype_cxx(type_info: type | TypeInfo) -> str:
356-
if not isinstance(type_info, TypeInfo):
357-
if (type_info := getattr(type_info, "_mlc_type_info", None)) is None: # type: ignore[assignment]
358-
raise ValueError(f"Invalid type: {type_info}")
354+
def _prototype_cxx(
355+
type_info: TypeInfo,
356+
export_macro: str = "_EXPORTS",
357+
) -> str:
359358
assert isinstance(type_info, TypeInfo)
360-
361359
parent_type_info = type_info.get_parent()
362360
namespaces = type_info.type_key.split(".")
363361
cls_name = namespaces[-1]
@@ -388,22 +386,22 @@ def prototype_cxx(type_info: type | TypeInfo) -> str:
388386
if i != 0:
389387
print(", ", file=io, end="")
390388
print(f"{ty} {name}", file=io, end="")
391-
print("): ", file=io, end="")
392-
for i, (name, _) in enumerate(fields):
393-
if i != 0:
394-
print(", ", file=io, end="")
395-
print(f"{name}({name})", file=io, end="")
389+
print("): _mlc_header{}", file=io, end="")
390+
for name, _ in fields:
391+
print(f", {name}({name})", file=io, end="")
396392
print(" {}", file=io)
397393
# Step 2.3. Macro to define object type
398394
print(
399-
f' MLC_DEF_DYN_TYPE(_EXPORTS, {cls_name}Obj, {parent_obj_name}, "{type_info.type_key}");',
395+
f' MLC_DEF_DYN_TYPE({export_macro}, {cls_name}Obj, {parent_obj_name}, "{type_info.type_key}");',
400396
file=io,
401397
)
402398
print(f"}}; // struct {cls_name}Obj\n", file=io)
403399
# Step 3. Object reference class
404400
print(f"struct {cls_name} : public {parent_ref_name} {{", file=io)
405401
# Step 3.1. Define fields for reflection
406-
print(f" MLC_DEF_OBJ_REF(_EXPORTS, {cls_name}, {cls_name}Obj, {parent_ref_name})", file=io)
402+
print(
403+
f" MLC_DEF_OBJ_REF({export_macro}, {cls_name}, {cls_name}Obj, {parent_ref_name})", file=io
404+
)
407405
for name, _ in fields:
408406
print(f' .Field("{name}", &{cls_name}Obj::{name})', file=io)
409407
# Step 3.2. Define `__init__` method for reflection
@@ -416,3 +414,34 @@ def prototype_cxx(type_info: type | TypeInfo) -> str:
416414
for ns in reversed(namespaces[:-1]):
417415
print(f"}} // namespace {ns}", file=io)
418416
return io.getvalue().rstrip()
417+
418+
419+
def prototype(
    match: str | type | TypeInfo | Callable[[TypeInfo], bool],
    lang: Literal["c++", "py"] = "c++",
    export_macro: str = "_EXPORTS",
) -> str:
    """Render the source-code prototype of one or more registered types.

    Parameters
    ----------
    match
        Selects the types to render:

        - a ``TypeInfo``: rendered as-is;
        - a class: its ``_mlc_type_info`` attribute is rendered;
        - a ``str``: compiled as a regular expression and matched against the
          *entire* ``type_key`` of every entry returned by ``type_table()``
          (note that ``.`` is a regex wildcard here, not a literal dot);
        - any other callable: used as a predicate over every non-empty entry
          returned by ``type_table()``.
    lang
        Output language, either ``"c++"`` or ``"py"``.
    export_macro
        Export macro name forwarded to the C++ generator; unused when
        ``lang="py"``.

    Returns
    -------
    str
        The prototypes of all selected types, separated by a blank line.
        Empty if nothing was selected.

    Raises
    ------
    ValueError
        If ``match`` is of an unsupported kind, if it is a class that does not
        carry a ``TypeInfo`` in ``_mlc_type_info``, or if ``lang`` is not
        recognized.
    """
    type_info_list: list[TypeInfo]
    if isinstance(match, TypeInfo):
        type_info_list = [match]
    elif isinstance(match, type):
        # Classes are callable, so they must be handled before the generic
        # predicate branch; otherwise a class lacking `_mlc_type_info` would be
        # instantiated once per registered type instead of being rejected.
        type_info = getattr(match, "_mlc_type_info", None)
        if not isinstance(type_info, TypeInfo):
            raise ValueError(f"Invalid `match`: {match}")
        type_info_list = [type_info]
    elif isinstance(match, str):
        pattern = re.compile(match)
        # Falsy entries in the table are skipped (empty slots, presumably).
        type_info_list = [i for i in type_table() if i and pattern.fullmatch(i.type_key)]
    elif callable(match):
        type_info_list = [i for i in type_table() if i and match(i)]
    else:
        raise ValueError(f"Invalid `match`: {match}")

    fn: Callable[[TypeInfo], str]
    if lang == "c++":
        fn = functools.partial(_prototype_cxx, export_macro=export_macro)
    elif lang == "py":
        fn = _prototype_py
    else:
        raise ValueError(f"Invalid `lang`: {lang}")
    return "\n\n".join(fn(i) for i in type_info_list)

tests/python/test_dataclasses_prototype.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class py_class:
4545
opt_dict_any_str: dict[Any, str] | None
4646
opt_dict_str_list_int: dict[str, list[int]] | None
4747
""".strip()
48-
actual = mlcd.prototype_py(PyClassForTest).strip()
48+
actual = mlcd.prototype(PyClassForTest, lang="py").strip()
4949
assert actual == expected
5050

5151

@@ -93,7 +93,7 @@ def test_prototype_cxx() -> None:
9393
::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any;
9494
::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str;
9595
::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>>> opt_dict_str_list_int;
96-
explicit py_classObj(bool bool_, int64_t i8, int64_t i16, int64_t i32, int64_t i64, double f32, double f64, void* raw_ptr, DLDataType dtype, DLDevice device, ::mlc::Any any, ::mlc::Func func, ::mlc::List<::mlc::Any> ulist, ::mlc::Dict<::mlc::Any, ::mlc::Any> udict, ::mlc::Str str_, ::mlc::Str str_readonly, ::mlc::List<::mlc::Any> list_any, ::mlc::List<::mlc::List<int64_t>> list_list_int, ::mlc::Dict<::mlc::Any, ::mlc::Any> dict_any_any, ::mlc::Dict<::mlc::Str, ::mlc::Any> dict_str_any, ::mlc::Dict<::mlc::Any, ::mlc::Str> dict_any_str, ::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>> dict_str_list_int, ::mlc::Optional<bool> opt_bool, ::mlc::Optional<int64_t> opt_i64, ::mlc::Optional<double> opt_f64, ::mlc::Optional<void*> opt_raw_ptr, ::mlc::Optional<DLDataType> opt_dtype, ::mlc::Optional<DLDevice> opt_device, ::mlc::Optional<::mlc::Func> opt_func, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_ulist, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_udict, ::mlc::Optional<::mlc::Str> opt_str, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_list_any, ::mlc::Optional<::mlc::List<::mlc::List<int64_t>>> opt_list_list_int, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_dict_any_any, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>>> opt_dict_str_list_int): bool_(bool_), i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any), dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool), opt_i64(opt_i64), opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func), opt_ulist(opt_ulist), 
opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any), opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any), opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {}
96+
explicit py_classObj(bool bool_, int64_t i8, int64_t i16, int64_t i32, int64_t i64, double f32, double f64, void* raw_ptr, DLDataType dtype, DLDevice device, ::mlc::Any any, ::mlc::Func func, ::mlc::List<::mlc::Any> ulist, ::mlc::Dict<::mlc::Any, ::mlc::Any> udict, ::mlc::Str str_, ::mlc::Str str_readonly, ::mlc::List<::mlc::Any> list_any, ::mlc::List<::mlc::List<int64_t>> list_list_int, ::mlc::Dict<::mlc::Any, ::mlc::Any> dict_any_any, ::mlc::Dict<::mlc::Str, ::mlc::Any> dict_str_any, ::mlc::Dict<::mlc::Any, ::mlc::Str> dict_any_str, ::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>> dict_str_list_int, ::mlc::Optional<bool> opt_bool, ::mlc::Optional<int64_t> opt_i64, ::mlc::Optional<double> opt_f64, ::mlc::Optional<void*> opt_raw_ptr, ::mlc::Optional<DLDataType> opt_dtype, ::mlc::Optional<DLDevice> opt_device, ::mlc::Optional<::mlc::Func> opt_func, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_ulist, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_udict, ::mlc::Optional<::mlc::Str> opt_str, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_list_any, ::mlc::Optional<::mlc::List<::mlc::List<int64_t>>> opt_list_list_int, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_dict_any_any, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>>> opt_dict_str_list_int): _mlc_header{}, bool_(bool_), i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any), dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool), opt_i64(opt_i64), opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func), 
opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any), opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any), opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {}
9797
MLC_DEF_DYN_TYPE(_EXPORTS, py_classObj, ::mlc::Object, "mlc.testing.py_class");
9898
}; // struct py_classObj
9999
@@ -142,5 +142,5 @@ def test_prototype_cxx() -> None:
142142
} // namespace testing
143143
} // namespace mlc
144144
""".strip()
145-
actual = mlcd.prototype_cxx(PyClassForTest).strip()
145+
actual = mlcd.prototype(PyClassForTest, lang="c++").strip()
146146
assert actual == expected

0 commit comments

Comments
 (0)