Skip to content

Commit

Permalink
feat(core): Support Bool in Any/AnyView (#7)
Browse files Browse the repository at this point in the history
This PR introduces `bool` support across the stack, including:
- Boolean values in `Any`/`AnyView`
- `Ref<bool>` and `Optional<bool>`
- Boolean fields and `Optional<bool>` fields in C/Python dataclasses
- Proper JSON parsing for boolean values
- Serialization/deserialization with boolean values
- Structural equal/hash with boolean values
- Boolean literals in Printer AST
  • Loading branch information
potatomashed authored Jan 19, 2025
1 parent a7b3452 commit c869192
Show file tree
Hide file tree
Showing 36 changed files with 1,181 additions and 354 deletions.
179 changes: 160 additions & 19 deletions cpp/c_api_tests.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <cstring>
#include <mlc/core/all.h>

namespace mlc {
Expand All @@ -7,6 +8,7 @@ namespace {

// Trivial functions registered under the "mlc.testing.*" names.
// cxx_none returns nothing and cxx_null returns nullptr; each of the others
// returns its single argument unchanged (bool, int, double, void *).
// NOTE(review): presumably used by callers to check that values of each type
// survive a round trip through the registered-function interface — confirm.
MLC_REGISTER_FUNC("mlc.testing.cxx_none").set_body([]() -> void { return; });
MLC_REGISTER_FUNC("mlc.testing.cxx_null").set_body([]() -> void * { return nullptr; });
MLC_REGISTER_FUNC("mlc.testing.cxx_bool").set_body([](bool x) -> bool { return x; });
MLC_REGISTER_FUNC("mlc.testing.cxx_int").set_body([](int x) -> int { return x; });
MLC_REGISTER_FUNC("mlc.testing.cxx_float").set_body([](double x) -> double { return x; });
MLC_REGISTER_FUNC("mlc.testing.cxx_ptr").set_body([](void *x) -> void * { return x; });
Expand All @@ -17,6 +19,7 @@ MLC_REGISTER_FUNC("mlc.testing.cxx_raw_str").set_body([](const char *x) { return
/**************** Reflection ****************/

struct TestingCClassObj : public Object {
bool bool_;
int8_t i8;
int16_t i16;
int32_t i32;
Expand All @@ -40,6 +43,7 @@ struct TestingCClassObj : public Object {
Dict<Any, Str> dict_any_str;
Dict<Str, List<int>> dict_str_list_int;

Optional<bool> opt_bool;
Optional<int64_t> opt_i64;
Optional<double> opt_f64;
Optional<void *> opt_raw_ptr;
Expand All @@ -57,22 +61,22 @@ struct TestingCClassObj : public Object {
Optional<Dict<Any, Str>> opt_dict_any_str;
Optional<Dict<Str, List<int>>> opt_dict_str_list_int;

explicit TestingCClassObj(int8_t i8, int16_t i16, int32_t i32, int64_t i64, float f32, double f64, void *raw_ptr,
DLDataType dtype, DLDevice device, Any any, Func func, UList ulist, UDict udict, Str str_,
Str str_readonly, List<Any> list_any, List<List<int>> list_list_int,
explicit TestingCClassObj(bool bool_, int8_t i8, int16_t i16, int32_t i32, int64_t i64, float f32, double f64,
void *raw_ptr, DLDataType dtype, DLDevice device, Any any, Func func, UList ulist,
UDict udict, Str str_, Str str_readonly, List<Any> list_any, List<List<int>> list_list_int,
Dict<Any, Any> dict_any_any, Dict<Str, Any> dict_str_any, Dict<Any, Str> dict_any_str,
Dict<Str, List<int>> dict_str_list_int, Optional<int64_t> opt_i64, Optional<double> opt_f64,
Optional<void *> opt_raw_ptr, Optional<DLDataType> opt_dtype, Optional<DLDevice> opt_device,
Optional<Func> opt_func, Optional<UList> opt_ulist, Optional<UDict> opt_udict,
Optional<Str> opt_str, Optional<List<Any>> opt_list_any,
Dict<Str, List<int>> dict_str_list_int, Optional<bool> opt_bool, Optional<int64_t> opt_i64,
Optional<double> opt_f64, Optional<void *> opt_raw_ptr, Optional<DLDataType> opt_dtype,
Optional<DLDevice> opt_device, Optional<Func> opt_func, Optional<UList> opt_ulist,
Optional<UDict> opt_udict, Optional<Str> opt_str, Optional<List<Any>> opt_list_any,
Optional<List<List<int>>> opt_list_list_int, Optional<Dict<Any, Any>> opt_dict_any_any,
Optional<Dict<Str, Any>> opt_dict_str_any, Optional<Dict<Any, Str>> opt_dict_any_str,
Optional<Dict<Str, List<int>>> opt_dict_str_list_int)
: i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device),
any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any),
list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any),
dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_i64(opt_i64), opt_f64(opt_f64),
opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func),
: bool_(bool_), i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype),
device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly),
list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any),
dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool), opt_i64(opt_i64),
opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func),
opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any),
opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any),
opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {}
Expand All @@ -84,6 +88,7 @@ struct TestingCClassObj : public Object {

struct TestingCClass : public ObjectRef {
MLC_DEF_OBJ_REF(MLC_EXPORTS, TestingCClass, TestingCClassObj, ObjectRef)
.Field("bool_", &TestingCClassObj::bool_)
.Field("i8", &TestingCClassObj::i8)
.Field("i16", &TestingCClassObj::i16)
.Field("i32", &TestingCClassObj::i32)
Expand All @@ -105,6 +110,7 @@ struct TestingCClass : public ObjectRef {
.Field("dict_str_any", &TestingCClassObj::dict_str_any)
.Field("dict_any_str", &TestingCClassObj::dict_any_str)
.Field("dict_str_list_int", &TestingCClassObj::dict_str_list_int)
.Field("opt_bool", &TestingCClassObj::opt_bool)
.Field("opt_i64", &TestingCClassObj::opt_i64)
.Field("opt_f64", &TestingCClassObj::opt_f64)
.Field("opt_raw_ptr", &TestingCClassObj::opt_raw_ptr)
Expand All @@ -121,13 +127,13 @@ struct TestingCClass : public ObjectRef {
.Field("opt_dict_any_str", &TestingCClassObj::opt_dict_any_str)
.Field("opt_dict_str_list_int", &TestingCClassObj::opt_dict_str_list_int)
.MemFn("i64_plus_one", &TestingCClassObj::i64_plus_one)
.StaticFn("__init__",
InitOf<TestingCClassObj, int8_t, int16_t, int32_t, int64_t, float, double, void *, DLDataType, DLDevice,
Any, Func, UList, UDict, Str, Str, List<Any>, List<List<int>>, Dict<Any, Any>, Dict<Str, Any>,
Dict<Any, Str>, Dict<Str, List<int>>, Optional<int64_t>, Optional<double>, Optional<void *>,
Optional<DLDataType>, Optional<DLDevice>, Optional<Func>, Optional<UList>, Optional<UDict>,
Optional<Str>, Optional<List<Any>>, Optional<List<List<int>>>, Optional<Dict<Any, Any>>,
Optional<Dict<Str, Any>>, Optional<Dict<Any, Str>>, Optional<Dict<Str, List<int>>>>);
.StaticFn("__init__", InitOf<TestingCClassObj, bool, int8_t, int16_t, int32_t, int64_t, float, double, void *,
DLDataType, DLDevice, Any, Func, UList, UDict, Str, Str, List<Any>, List<List<int>>,
Dict<Any, Any>, Dict<Str, Any>, Dict<Any, Str>, Dict<Str, List<int>>, Optional<bool>,
Optional<int64_t>, Optional<double>, Optional<void *>, Optional<DLDataType>,
Optional<DLDevice>, Optional<Func>, Optional<UList>, Optional<UDict>, Optional<Str>,
Optional<List<Any>>, Optional<List<List<int>>>, Optional<Dict<Any, Any>>,
Optional<Dict<Str, Any>>, Optional<Dict<Any, Str>>, Optional<Dict<Str, List<int>>>>);
};

/**************** Traceback ****************/
Expand Down Expand Up @@ -191,5 +197,140 @@ MLC_REGISTER_FUNC("mlc.testing.nested_type_checking_list").set_body([](Str name)
MLC_UNREACHABLE();
});

/**************** Visitor ****************/

// Registers "mlc.testing.VisitFields".
// Runs ::mlc::core::VisitFields over `root` and returns UList{types, names, values},
// where entry i of each list describes the i-th visited field:
//   types[i]  - a label naming the C++ overload that received the field,
//   names[i]  - the field's name taken from MLCTypeField::name,
//   values[i] - the field's current value converted to Any.
MLC_REGISTER_FUNC("mlc.testing.VisitFields").set_body([](ObjectRef root) {
  struct Visitor {
    // Output lists owned by the enclosing lambda; order matters for the
    // aggregate initialization below.
    List<Str> *types;
    List<Str> *names;
    UList *values;

    // Appends one (type label, field name, value) record to the three lists.
    void Push(const char *ty, const char *name, Any value) {
      types->push_back(ty);
      names->push_back(name);
      values->push_back(value);
    }

    // Type-erased and object-like fields.
    void operator()(MLCTypeField *field, const Any *value) { Push("Any", field->name, *value); }
    void operator()(MLCTypeField *field, ObjectRef *value) { Push("ObjectRef", field->name, *value); }

    // Optional fields.
    void operator()(MLCTypeField *field, Optional<ObjectRef> *value) {
      Push("Optional<ObjectRef>", field->name, *value);
    }
    void operator()(MLCTypeField *field, Optional<bool> *value) { Push("Optional<bool>", field->name, *value); }
    void operator()(MLCTypeField *field, Optional<int64_t> *value) {
      Push("Optional<int64_t>", field->name, *value);
    }
    void operator()(MLCTypeField *field, Optional<double> *value) { Push("Optional<double>", field->name, *value); }
    void operator()(MLCTypeField *field, Optional<DLDevice> *value) {
      Push("Optional<DLDevice>", field->name, *value);
    }
    void operator()(MLCTypeField *field, Optional<DLDataType> *value) {
      Push("Optional<DLDataType>", field->name, *value);
    }
    void operator()(MLCTypeField *field, Optional<void *> *value) {
      Push("Optional<void *>", field->name, *value);
    }

    // Plain-old-data fields.
    void operator()(MLCTypeField *field, bool *value) { Push("bool", field->name, *value); }
    void operator()(MLCTypeField *field, int8_t *value) { Push("int8_t", field->name, *value); }
    void operator()(MLCTypeField *field, int16_t *value) { Push("int16_t", field->name, *value); }
    void operator()(MLCTypeField *field, int32_t *value) { Push("int32_t", field->name, *value); }
    void operator()(MLCTypeField *field, int64_t *value) { Push("int64_t", field->name, *value); }
    void operator()(MLCTypeField *field, float *value) { Push("float", field->name, *value); }
    void operator()(MLCTypeField *field, double *value) { Push("double", field->name, *value); }
    void operator()(MLCTypeField *field, DLDataType *value) { Push("DLDataType", field->name, *value); }
    void operator()(MLCTypeField *field, DLDevice *value) { Push("DLDevice", field->name, *value); }

    // Raw pointer fields.
    void operator()(MLCTypeField *field, void **value) { Push("void *", field->name, *value); }
    void operator()(MLCTypeField *field, const char **value) { Push("const char *", field->name, *value); }
  };

  List<Str> types;
  List<Str> names;
  UList values;
  MLCTypeInfo *root_info = ::mlc::Lib::GetTypeInfo(root.GetTypeIndex());
  ::mlc::core::VisitFields(root.get(), root_info, Visitor{&types, &names, &values});
  return UList{types, names, values};
});

// Thrown by FieldGetter / FieldSetter once the requested field has been handled,
// so the field traversal stops at the first match. Carries no payload.
struct FieldFoundException : ::std::exception {};

struct FieldGetter {
void operator()(MLCTypeField *f, const Any *any) { Check(f->name, any); }
void operator()(MLCTypeField *f, ObjectRef *obj) { Check(f->name, obj); }
void operator()(MLCTypeField *f, Optional<ObjectRef> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<bool> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<int64_t> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<double> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<DLDevice> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<DLDataType> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, bool *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int8_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int16_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int32_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int64_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, float *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, double *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, DLDataType *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, DLDevice *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, Optional<void *> *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, void **v) { Check(f->name, v); }
void operator()(MLCTypeField *f, const char **v) { Check(f->name, v); }

template <typename T> void Check(const char *name, T *v) {
if (std::strcmp(name, target_name) == 0) {
*ret = Any(*v);
throw FieldFoundException();
}
}
const char *target_name;
Any *ret;
};

struct FieldSetter {
void operator()(MLCTypeField *f, Any *any) { Check(f->name, any); }
void operator()(MLCTypeField *f, ObjectRef *obj) { Check(f->name, obj); }
void operator()(MLCTypeField *f, Optional<ObjectRef> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<bool> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<int64_t> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<double> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<DLDevice> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, Optional<DLDataType> *opt) { Check(f->name, opt); }
void operator()(MLCTypeField *f, bool *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int8_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int16_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int32_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, int64_t *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, float *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, double *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, DLDataType *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, DLDevice *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, Optional<void *> *v) { Check(f->name, v); }
void operator()(MLCTypeField *f, void **v) { Check(f->name, v); }
void operator()(MLCTypeField *f, const char **v) { Check(f->name, v); }

template <typename T> void Check(const char *name, T *v) {
if (std::strcmp(name, target_name) == 0) {
if constexpr (std::is_same_v<T, Any>) {
*v = src;
} else {
*v = src.operator T();
}
throw FieldFoundException();
}
}
const char *target_name;
Any src;
};

// Registers "mlc.testing.FieldGet".
// Visits the fields of `root` with a FieldGetter and returns the value of the
// field named `target_name` as an Any. Raises ValueError ("Field not found: ...")
// when the traversal finishes without a match.
MLC_REGISTER_FUNC("mlc.testing.FieldGet").set_body([](ObjectRef root, const char *target_name) {
  Any ret;
  MLCTypeInfo *root_info = ::mlc::Lib::GetTypeInfo(root.GetTypeIndex());
  bool found = false;
  try {
    ::mlc::core::VisitFields(root.get(), root_info, FieldGetter{target_name, &ret});
  } catch (const FieldFoundException &) {
    // The getter signals a successful match by throwing.
    found = true;
  }
  if (!found) {
    MLC_THROW(ValueError) << "Field not found: " << target_name;
    MLC_UNREACHABLE();
  }
  return ret;
});

// Registers "mlc.testing.FieldSet".
// Visits the fields of `root` with a FieldSetter, storing `src` into the field
// named `target_name`. Raises ValueError ("Field not found: ...") when the
// traversal finishes without a match.
MLC_REGISTER_FUNC("mlc.testing.FieldSet").set_body([](ObjectRef root, const char *target_name, Any src) {
  MLCTypeInfo *root_info = ::mlc::Lib::GetTypeInfo(root.GetTypeIndex());
  bool found = false;
  try {
    ::mlc::core::VisitFields(root.get(), root_info, FieldSetter{target_name, src});
  } catch (const FieldFoundException &) {
    // The setter signals a successful match by throwing.
    found = true;
  }
  if (!found) {
    MLC_THROW(ValueError) << "Field not found: " << target_name;
    MLC_UNREACHABLE();
  }
});

} // namespace
} // namespace mlc
48 changes: 32 additions & 16 deletions cpp/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#include <sstream>

namespace mlc {
namespace core {
namespace {

using mlc::core::TopoVisit;
using mlc::core::VisitFields;

mlc::Str Serialize(Any any);
Any Deserialize(const char *json_str, int64_t json_str_len);
Any JSONLoads(const char *json_str, int64_t json_str_len);
Expand Down Expand Up @@ -35,11 +37,13 @@ inline mlc::Str Serialize(Any any) {
// clang-format off
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<bool> *opt) { if (const bool *v = opt->get()) EmitBool(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
// clang-format on
MLC_INLINE void operator()(MLCTypeField *, bool *v) { EmitBool(*v); }
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast<int64_t>(*v)); }
Expand All @@ -56,6 +60,7 @@ inline mlc::Str Serialize(Any any) {
MLC_THROW(TypeError) << "Unserializable type: const char *";
}
inline void EmitNil() { (*os) << ", null"; }
inline void EmitBool(bool v) { (*os) << ", " << (v ? "true" : "false"); }
inline void EmitFloat(double v) { (*os) << ", " << std::fixed << std::setprecision(19) << v; }
inline void EmitInt(int64_t v) {
int32_t type_int = (*get_json_type_index)(TypeTraits<int64_t>::type_str);
Expand All @@ -73,6 +78,8 @@ inline mlc::Str Serialize(Any any) {
int32_t type_index = any->type_index;
if (type_index == kMLCNone) {
EmitNil();
} else if (type_index == kMLCBool) {
EmitBool(any->operator bool());
} else if (type_index == kMLCInt) {
EmitInt(any->operator int64_t());
} else if (type_index == kMLCFloat) {
Expand Down Expand Up @@ -144,6 +151,9 @@ inline mlc::Str Serialize(Any any) {
TopoVisit(any.operator Object *(), nullptr, on_visit);
} else if (any.type_index == kMLCNone) {
os << "null";
} else if (any.type_index == kMLCBool) {
bool v = any.operator bool();
os << (v ? "true" : "false");
} else if (any.type_index == kMLCInt) {
int32_t type_int = get_json_type_index(TypeTraits<int64_t>::type_str);
int64_t v = any;
Expand Down Expand Up @@ -211,7 +221,8 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
}
} else if (arg.type_index == kMLCList) {
list[j] = invoke_init(arg.operator UList());
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCFloat || arg.type_index == kMLCNone) {
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
arg.type_index == kMLCNone) {
// Do nothing
} else {
MLC_THROW(ValueError) << "Unexpected value: " << arg;
Expand All @@ -223,6 +234,7 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
values[i] = values[k];
} else if (obj.type_index == kMLCStr) {
// Do nothing
// TODO: how about kMLCBool, kMLCFloat, kMLCNone?
} else {
MLC_THROW(ValueError) << "Unexpected value: " << obj;
}
Expand Down Expand Up @@ -277,10 +289,10 @@ inline Any JSONLoads(const char *json_str, int64_t json_str_len) {
Any ParseBoolean() {
if (PeekChar() == 't') {
ExpectString("true", 4);
return Any(1);
return Any(true);
} else {
ExpectString("false", 5);
return Any(0);
return Any(false);
}
}

Expand Down Expand Up @@ -479,23 +491,27 @@ inline Any JSONLoads(const char *json_str, int64_t json_str_len) {
}
return JSONParser{0, json_str_len, json_str}.Parse();
}
} // namespace
} // namespace mlc

MLC_REGISTER_FUNC("mlc.core.JSONLoads").set_body([](AnyView json_str) {
namespace mlc {

Any JSONLoads(AnyView json_str) {
if (json_str.type_index == kMLCRawStr) {
return ::mlc::core::JSONLoads(json_str.operator const char *());
return JSONLoads(json_str.operator const char *());
} else {
::mlc::Str str = json_str;
return ::mlc::core::JSONLoads(str);
return JSONLoads(json_str.operator Str());
}
});
MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize);
MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) {
}

Any JSONDeserialize(AnyView json_str) {
if (json_str.type_index == kMLCRawStr) {
return ::mlc::core::Deserialize(json_str.operator const char *());
return Deserialize(json_str.operator const char *());
} else {
return ::mlc::core::Deserialize(json_str.operator ::mlc::Str());
return Deserialize(json_str.operator ::mlc::Str());
}
});
} // namespace
} // namespace core
}

// Serializes `source` by forwarding it to Serialize and returns the resulting string.
Str JSONSerialize(AnyView source) {
  Str serialized = Serialize(source);
  return serialized;
}

} // namespace mlc
Loading

0 comments on commit c869192

Please sign in to comment.