Add FP8 MoE for turbomind #3601

Status: Open
Wants to merge 137 commits into base: main
Commits (137), all by lzhangzz:
404263d  low level abstraction (Mar 27, 2025)
81bfa75  refactor (Apr 2, 2025)
770b85d  eliminate template (Apr 7, 2025)
e3a9619  remove unused (Apr 7, 2025)
6b9a433  refactor bindings (Apr 7, 2025)
613aeec  simplify lm head (Apr 7, 2025)
e3fe34c  refactor weight (Apr 8, 2025)
1e057d1  fix tp (Apr 8, 2025)
40e9097  cublas (Apr 8, 2025)
6fc9cc9  Merge remote-tracking branch 'origin/main' into core (Apr 9, 2025)
ff3b5f7  refactor sampling (Apr 10, 2025)
06ff641  remove unused (Apr 10, 2025)
14a7f45  simplify (Apr 11, 2025)
096155c  fix AWQ support (Apr 11, 2025)
5fd35ae  fix moe (Apr 11, 2025)
0c5ef46  fix nccl lm_head (Apr 11, 2025)
c2020b2  fix (Apr 11, 2025)
510675c  refactor data types (Apr 15, 2025)
00b121e  skip legacy ut (Apr 15, 2025)
88d17d4  simplify (Apr 15, 2025)
699c24f  rename data types (Apr 15, 2025)
3ffd070  refactor (Apr 15, 2025)
eed6bfb  refactor runtime states (Apr 16, 2025)
d2ec3af  fix msvc build (Apr 16, 2025)
2529631  fix msvc build (Apr 16, 2025)
1d77856  fix msvc build (Apr 16, 2025)
6e728cf  fix msvc build (Apr 16, 2025)
1b6a80d  fix msvc build (Apr 16, 2025)
0d976d3  fix msvc build (Apr 16, 2025)
18e7602  fix msvc build (Apr 16, 2025)
7fec496  fix msvc build (Apr 16, 2025)
69b1841  fix msvc build (Apr 16, 2025)
c8bc36d  format (Apr 16, 2025)
8161c0d  remove unused (Apr 16, 2025)
7459992  fix msvc build (Apr 16, 2025)
d38421f  fix msvc build (Apr 16, 2025)
7d6ab03  fix msvc build (Apr 16, 2025)
b214a0e  fix msvc build (Apr 16, 2025)
3ab38ca  fix msvc build (Apr 16, 2025)
f394ef0  fix msvc build (Apr 16, 2025)
105f1cc  fix msvc build (Apr 16, 2025)
5ccf30c  fix msvc build (Apr 16, 2025)
b59620c  fix msvc build (Apr 16, 2025)
a243da0  fix msvc build (Apr 16, 2025)
42172d3  fix msvc build (Apr 16, 2025)
4d9910a  fix msvc build (Apr 16, 2025)
bf7c213  fix msvc build (Apr 16, 2025)
8651edd  fix msvc build (Apr 16, 2025)
4788b80  fix msvc build (Apr 16, 2025)
529225d  fix msvc build (Apr 16, 2025)
6fd5f72  fix msvc build (Apr 17, 2025)
86d4e86  fix msvc build (Apr 17, 2025)
8ea7e20  fix ut & msvc build (Apr 17, 2025)
dcf9669  fix ut & msvc build (Apr 17, 2025)
7f8974b  fix gcc build (Apr 17, 2025)
646813b  fix lint & ut (Apr 17, 2025)
98f4840  fix lint (Apr 17, 2025)
ea07957  fetch Catch2 when building tests (Apr 17, 2025)
5d69923  rewind msvc build (Apr 17, 2025)
d0079b5  fix sampling (Apr 18, 2025)
15b4007  fp8 round trip test (Apr 21, 2025)
72fa3ee  pseudo quant test (Apr 22, 2025)
c4de357  initial sm90 gemm kernel (Apr 24, 2025)
26b270a  optimize smem desc (Apr 24, 2025)
68ac168  multiple warp groups (Apr 24, 2025)
4e573c5  optimize smem desc (Apr 24, 2025)
4b801eb  flush TMA ops (Apr 24, 2025)
97407ad  TMA epilogue (Apr 24, 2025)
a97a3bb  clean-up (Apr 24, 2025)
b26d7c8  launch config & fence operand (Apr 24, 2025)
24520ee  tuning (Apr 24, 2025)
eb7a50e  minor (Apr 24, 2025)
c477918  TMA multicast (Apr 25, 2025)
04257a6  pipeline (Apr 25, 2025)
da5b22d  warp specialization (Apr 25, 2025)
c728db2  persistent kernel (Apr 27, 2025)
7e03761  better TMA multicast (Apr 28, 2025)
4fb3799  initial fp8 gemm (Apr 29, 2025)
27f1d29  optimize (Apr 29, 2025)
876fbe7  optimize (Apr 30, 2025)
316f2f8  cluster boundary check (Apr 30, 2025)
fb9dd9a  slow (Apr 30, 2025)
dc6f660  revert (Apr 30, 2025)
fd515e1  v2 (Apr 30, 2025)
f0e4e4e  optimize (May 1, 2025)
4bb1033  fix scaling (May 6, 2025)
87ab9df  fix scaling (May 6, 2025)
c1ec1de  optimize (May 7, 2025)
45b8218  TMA store swizzle (May 7, 2025)
26b6e21  better TMA store swizzle (May 8, 2025)
49e6ac9  clean up (May 8, 2025)
c2e3564  prefetch U (May 8, 2025)
42571c2  prefetch V (May 8, 2025)
108e2aa  fix V for multi wg (May 9, 2025)
89fbbb7  fix U stride for TMA (May 9, 2025)
464a2bc  qwen3 dense fp8 (May 13, 2025)
0653c65  fix tma multicast (May 13, 2025)
c14fcc4  fix producer register count (May 13, 2025)
dfbae20  decouple epilogue (May 13, 2025)
731f2c5  warpspecialized pingpong (May 15, 2025)
77ea397  fix multicast (May 15, 2025)
2d991a1  cluster layout (May 15, 2025)
86344c7  update (May 16, 2025)
56e181b  optimize (May 16, 2025)
0a422d3  larger tiles (May 16, 2025)
a004cd7  multicast U (May 19, 2025)
07c181d  rename (May 19, 2025)
b18f336  v3 (May 19, 2025)
88018a6  optimize v3 (May 19, 2025)
5959297  tune (May 20, 2025)
a3022a3  `tensormap.replace` (May 21, 2025)
7eda414  refactor (May 22, 2025)
2dc68c2  v4 (May 23, 2025)
5ccf1d8  init moe support (May 26, 2025)
d98be0f  fix (May 27, 2025)
2de6d70  fix v stride (May 27, 2025)
5ef84ee  fix empty tile & unaligned U (May 27, 2025)
7eb5f69  multicast schedule (May 27, 2025)
bac2803  scheduling (May 28, 2025)
062872d  fix scheduling (May 28, 2025)
493ab50  tune (May 29, 2025)
e3f3818  fix non-grouped gemm (May 29, 2025)
e6119a3  tune group gemm (May 29, 2025)
cfacea2  load weight by pointers (May 29, 2025)
90ed985  fp8 moe (May 30, 2025)
4a58024  Merge remote-tracking branch 'origin/main' into gemm3 (May 30, 2025)
036d67a  switch to https git (Jun 2, 2025)
8474114  fix cutlass tag (Jun 2, 2025)
e4b4c49  90 -> 90a (Jun 2, 2025)
75eb1ed  fix missing headers (Jun 2, 2025)
3493b2d  fix cuda-11 (Jun 2, 2025)
ce5ac45  guard sm90 (Jun 2, 2025)
863be32  update (Jun 2, 2025)
11289e4  v5 (Jun 2, 2025)
d31e65d  update (Jun 2, 2025)
2fa2ad5  update (Jun 2, 2025)
19b93d2  v5 (Jun 3, 2025)
35 changes: 18 additions & 17 deletions CMakeLists.txt
@@ -38,21 +38,6 @@ option(BUILD_FAST_MATH "Build in fast math mode" ON)
include(FetchContent)

if (BUILD_TEST)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG 6f47420213f757831fae65c686aa471749fa8d60
GIT_SHALLOW ON
)

set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

FetchContent_MakeAvailable(repo-cutlass)

set(CUTLASS_HEADER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include)
set(CUTLASS_EXTENSIONS_DIR ${PROJECT_SOURCE_DIR}/src/turbomind/cutlass_extensions/include)


FetchContent_Declare(
Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
@@ -61,6 +46,19 @@ if (BUILD_TEST)
FetchContent_MakeAvailable(Catch2)
endif()


FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG v3.9.2
GIT_SHALLOW ON
)

set(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES ON CACHE BOOL "Enable extended GMMA shapes")
set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

FetchContent_MakeAvailable(repo-cutlass)

FetchContent_Declare(
yaml-cpp
GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
@@ -129,10 +127,13 @@ if (NOT CMAKE_CUDA_ARCHITECTURES)
list(APPEND CMAKE_CUDA_ARCHITECTURES 86-real)
endif ()
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-real)
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
endif ()
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.0")
list(APPEND CMAKE_CUDA_ARCHITECTURES 90a-real)
endif ()
if (MSVC)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90-real)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real)
endif ()
endif ()

2 changes: 1 addition & 1 deletion lmdeploy/cli/utils.py
@@ -124,7 +124,7 @@ def model_format(parser, default: str = None):
return parser.add_argument('--model-format',
type=str,
default=default,
choices=['hf', 'awq', 'gptq'],
choices=['hf', 'awq', 'gptq', 'fp8'],
help='The format of input model. `hf` means `hf_llama`, '
'`awq` represents the quantized model by AWQ,'
' and `gptq` refers to the quantized model by GPTQ')
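For context, the same format choice is also reachable from the Python API; below is a minimal sketch, assuming `TurbomindEngineConfig` accepts `model_format='fp8'` once this PR lands (the checkpoint path is a placeholder, not taken from this PR):

# Minimal sketch: load an FP8 checkpoint with the turbomind backend.
# Assumption: TurbomindEngineConfig.model_format accepts 'fp8' after this change.
from lmdeploy import pipeline, TurbomindEngineConfig

pipe = pipeline(
    'path/to/fp8-moe-checkpoint',  # placeholder path
    backend_config=TurbomindEngineConfig(model_format='fp8', tp=1),
)
print(pipe(['What does FP8 block quantization change?'])[0].text)
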
9 changes: 7 additions & 2 deletions lmdeploy/turbomind/deploy/converter.py
@@ -20,7 +20,7 @@
from .source_model.base import INPUT_MODELS
from .target_model.base import OUTPUT_MODELS, BaseOutputModel

SUPPORTED_FORMATS = ['hf', 'awq', 'gptq', None]
SUPPORTED_FORMATS = ['hf', 'awq', 'gptq', 'fp8', None]
logger = get_logger('lmdeploy')


@@ -102,6 +102,9 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s
if model_format in ['awq', 'gptq']:
weight_type = 'int4'
group_size = 128 if group_size == 0 else group_size
elif model_format == 'fp8':
weight_type = 'fp8'
group_size = 128
else:
torch_dtype = getattr(model_config, 'torch_dtype', 'float16')
TORCH_DTYPE_MAP = {torch.bfloat16: 'bfloat16', torch.float16: 'float16'}
@@ -112,7 +115,7 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s
weight_type = 'bfloat16'

if dtype == 'auto':
weight_type = weight_type if weight_type in ['float16', 'bfloat16', 'int4'] else 'float16'
weight_type = weight_type if weight_type in ['float16', 'bfloat16', 'int4', 'fp8'] else 'float16'
elif dtype in ['float16', 'bfloat16']:
if weight_type == 'int4':
logger.warning(f'The model {model_path} is a quantized model, so the '
@@ -197,6 +200,8 @@ def get_tm_model(model_path,
assert not quant_config.get('desc_act', False) and \
quant_config.get('sym', True), \
f'unsupported quant config: {quant_config}'
elif quant_method == 'fp8':
pass
else:
assert 0, f'unsupported quant_config: {quant_config}'

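For reference, the `quant_method == 'fp8'` branch above matches checkpoints that declare block-wise FP8 quantization in their config.json. A representative `quantization_config` (DeepSeek/Qwen FP8 style; shown as an assumption, not read from this PR) looks like:

# Representative quantization_config of a block-quantized FP8 checkpoint.
# The converter only inspects quant_method; weight_type is then fixed to 'fp8'
# and group_size to 128, matching the 128x128 weight blocks.
quant_config = {
    'quant_method': 'fp8',
    'fmt': 'e4m3',                    # weights stored as float8_e4m3fn
    'activation_scheme': 'dynamic',   # activations scaled at runtime
    'weight_block_size': [128, 128],  # one scale per 128x128 weight block
}
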
20 changes: 12 additions & 8 deletions lmdeploy/turbomind/deploy/module.py
@@ -100,14 +100,15 @@ def __init__(self, model: BaseOutputModel):
self.inter_size = model.model_config.inter_size
self.group_size = max(1, model.model_config.group_size)

def _export(self, inter_size: int, fmt: str, idx: int, w123, kind: str, pack_fn, apply_gs=False):
def _export(self, inter_size: int, fmt: str, idx: int, w123, kind: str, pack_fn, apply_gs=False, block_size=1, **kwargs):
is_lora_a, is_lora_b = get_lora_flags(kind)
w1, w2, w3 = map(transpose, w123)

if not is_lora_a:
# TODO: handle padding for block_size != 1
if not is_lora_a and block_size == 1:
w1 = pad_out_dims(w1, inter_size)
w3 = pad_out_dims(w3, inter_size)
if not is_lora_b:
if not is_lora_b and block_size == 1:
group_size = self.group_size if apply_gs else 1
w2 = pad_in_dims(w2, inter_size // group_size)

@@ -171,12 +172,15 @@ def __init__(self, model: BaseOutputModel):
self.attn_bias = model.model_config.attn_bias
self.qk_norm = model.model_config.qk_norm

def _reorder_and_merge(self, qkvo):
def _reorder_and_merge(self, qkvo, block_size):
q, k, v, o = qkvo
# reorder output dim for tm's rotary embedding layout
if self.model.permute_qk:
q = permute_v2(q, self.head_dim)
k = permute_v2(k, self.head_dim)
if block_size == 1:
q = permute_v2(q, self.head_dim)
k = permute_v2(k, self.head_dim)
else:
assert block_size % self.head_dim == 0
qkv = merge_qkv_v2(q, k, v, self.tp)
# zero bias for `wo` when `w_qkv` has bias but `wo` doesn't
if o is None and q.dim() == 1:
@@ -204,7 +208,7 @@ def _repeat(x):

return (q, k, v, o)

def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
def _export(self, idx: int, qkvo, kind: str, pack_fn, block_size=1, **kwargs):
if all(x is None for x in qkvo):
return
is_lora_a, is_lora_b = get_lora_flags(kind)
@@ -214,7 +218,7 @@ def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
qkvo = tuple(map(transpose, qkvo))
if self.model.repeat_kv:
qkvo = self._repeat_kv(qkvo, kind)
qkv, o = self._reorder_and_merge(qkvo)
qkv, o = self._reorder_and_merge(qkvo, block_size)
self.model.save_split(pack_fn(qkv),
self._attn.format(idx, 'w_qkv', kind),
split_dim=-1,
20 changes: 20 additions & 0 deletions lmdeploy/turbomind/deploy/parameter.py
@@ -13,6 +13,15 @@ def to_half(x: torch.Tensor):
return x.to(torch.half)


def to_float(x: torch.Tensor):
return x.to(torch.float)


def to_fp8(x: torch.Tensor):
assert x.dtype == torch.uint8
return x.view(dtype=torch.float8_e4m3fn)


def pack_u4_row(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.uint8
xs = x.view(*x.shape[:-1], -1, 8).split(1, dim=-1)
@@ -51,6 +60,15 @@ def __call__(self, f, g, i):
f(i, g('qzeros'), 'zeros', to_half, apply_gs=True)


class WeightScaleInv(Parameter):
KEYS = '.weight_scale_inv', '.weight'

# TODO: flag any operations crossing the quant blocks as illegal
def __call__(self, f, g, i):
f(i, g('weight_scale_inv'), 'scales', to_float, block_size=128)
f(i, g('weight'), 'weight', identity)


class Weight(Parameter):
KEYS = '.weight',

@@ -79,6 +97,8 @@ def get_params(keys: List[str], bias=0):
ps.append(PLora())
if QuantWeightOnly.take(keys):
ps.append(QuantWeightOnly())
if WeightScaleInv.take(keys):
ps.append(WeightScaleInv())
if Weight.take(keys):
ps.append(Weight())
if bias and Bias.take(keys):
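To make the `WeightScaleInv` pairing concrete, here is a small dequantization sketch, assuming DeepSeek-style layout where `.weight` is float8_e4m3fn of shape (N, K) and `.weight_scale_inv` holds one float32 scale per 128x128 block; the helper name and shapes are illustrative and not part of the PR:

import torch

def dequant_block_fp8(weight_u8: torch.Tensor, scale_inv: torch.Tensor, block: int = 128):
    # weight_u8: uint8 view of a float8_e4m3fn weight, shape (N, K), as produced by process_fp8
    # scale_inv: float32 scales, shape (ceil(N/block), ceil(K/block))
    w = weight_u8.view(torch.float8_e4m3fn).float()  # same reinterpretation as to_fp8()/to_float()
    n, k = w.shape
    # broadcast each block scale over its 128x128 tile, trimming ragged edges
    s = scale_inv.repeat_interleave(block, dim=0)[:n].repeat_interleave(block, dim=1)[:, :k]
    return (w * s).to(torch.bfloat16)
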
13 changes: 13 additions & 0 deletions lmdeploy/turbomind/deploy/policy.py
@@ -46,10 +46,23 @@ def process_gptq(x: torch.Tensor, kind: str):
return x


def process_fp8(x: torch.Tensor, kind: str):
x = x.cuda()
if x.dtype == torch.float8_e4m3fn:
# some ops (e.g. torch.cat) for fp8 is not implemented in pytorch
return x.view(dtype=torch.uint8)
elif kind != 'weight_scale_inv' and x.dtype == torch.float:
return x.to(dtype=torch.bfloat16)
else:
return x


def get_input_policy(model_format):
if model_format == 'awq':
return process_awq_gemm
elif model_format == 'gptq':
return process_gptq
elif model_format == 'fp8':
return process_fp8
else:
return to_cuda
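
The uint8 reinterpretation in `process_fp8` is only a storage workaround for ops that PyTorch does not implement on fp8 tensors; a tiny round-trip sketch, assuming a PyTorch build with float8_e4m3fn support (2.1+):

import torch

x = torch.randn(4, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn)  # stand-in for a loaded fp8 weight
u8 = x.view(torch.uint8)              # what process_fp8 returns; plain uint8 ops work as usual
stacked = torch.cat([u8, u8], dim=0)  # per the comment above, cat is done on the uint8 view, not on fp8
restored = stacked.view(torch.float8_e4m3fn)  # to_fp8() undoes the view at export time
assert restored.dtype == torch.float8_e4m3fn and restored.shape == (8, 4)
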
10 changes: 5 additions & 5 deletions lmdeploy/turbomind/deploy/source_model/llama.py
@@ -45,15 +45,15 @@ def filter(self, pattern: str):

def tok_embeddings(self):
"""Get embeddings."""
return self.params.get(self.tok_embeddings_key, None)
return self.transform(self.params.get(self.tok_embeddings_key, None), 'weight')

def norm_weight(self):
"""Get norm."""
return self.params.get(self.norm_weight_key, None)
return self.transform(self.params.get(self.norm_weight_key, None), 'weight')

def output_weight(self):
"""Get output."""
return self.params.get(self.output_weight_key, None)
return self.transform(self.params.get(self.output_weight_key, None), 'weight')

def _transform(self, x: torch.Tensor, kind: str):
return self.processor(x, kind)
@@ -74,7 +74,7 @@ def attn(self, i: int, kind: str):

def attn_norm(self, i: int):
"""Get attn norm for layer i."""
return self.params[f'{self.attn_layer_prefix}.{i}.input_layernorm.weight']
return self.transform(self.params[f'{self.attn_layer_prefix}.{i}.input_layernorm.weight'], 'weight')

def _ffn(self, i: int, kind: str):
"""Get ffn kind for layer i."""
@@ -94,7 +94,7 @@ def ffn(self, i: int, kind: str):

def ffn_norm(self, i: int):
"""Get ffn norm for layer i."""
return self.params[f'{self.attn_layer_prefix}.{i}.post_attention_layernorm.weight']
return self.transform(self.params[f'{self.attn_layer_prefix}.{i}.post_attention_layernorm.weight'], 'weight')


@INPUT_MODELS.register_module(name='llama')
6 changes: 3 additions & 3 deletions lmdeploy/turbomind/deploy/source_model/qwen.py
@@ -131,7 +131,7 @@ def moe_ffn_expert(self, e=None, i=None, kind=None):
return (*result, )

def moe_ffn_gate(self, i):
return self.params.get(f'model.layers.{i}.mlp.gate.weight')
return self.transform(self.params.get(f'model.layers.{i}.mlp.gate.weight'), 'weight')

def _ffn(self, i: int, kind: str):
"""Get ffn kind for layer i."""
@@ -172,7 +172,7 @@ def qk_norm(self, i: int):
result = []
for x in ['q', 'k']:
name = f'{self.attn_layer_prefix}.{i}.self_attn.{x}_norm.weight'
result.append(self.params.get(name))
result.append(self.transform(self.params.get(name), 'weight'))
return (*result, )


@@ -193,7 +193,7 @@ def qk_norm(self, i: int):
result = []
for x in ['q', 'k']:
name = f'{self.attn_layer_prefix}.{i}.self_attn.{x}_norm.weight'
result.append(self.params.get(name))
result.append(self.transform(self.params.get(name), 'weight'))
return (*result, )


6 changes: 4 additions & 2 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -142,17 +142,19 @@ def _tofile(tensor, path):
elif len(self.tm_params) > 0:
tm_params = self.tm_params
weight_type = self.model_config.weight_type
assert weight_type in ['float16', 'bfloat16', 'int4']
assert weight_type in ['float16', 'bfloat16', 'int4', 'fp8']

# currently, the tensor type should in
# [torch.float, torch.half, torch.bfloat16, torch.int32]
torch_tensor = param.cuda().contiguous()
assert torch_tensor.dtype in [torch.int32, torch.float, torch.half, torch.bfloat16]
assert torch_tensor.dtype in [torch.int32, torch.float, torch.half, torch.bfloat16, torch.uint8]
if torch_tensor.dtype != torch.int32:
if weight_type in ['float16', 'int4']:
torch_tensor = torch_tensor.half()
elif weight_type == 'bfloat16':
torch_tensor = torch_tensor.bfloat16()
elif weight_type == 'fp8':
pass
else:
torch_tensor = torch_tensor.half()
for tm_tensor in tm_params[name]:
2 changes: 1 addition & 1 deletion src/turbomind/core/CMakeLists.txt
@@ -13,7 +13,7 @@ add_library(core STATIC
tensor.cu
module.cc)

target_link_libraries(core PUBLIC cuda_utils CUDA::cudart CUDA::cuda_driver)
target_link_libraries(core PUBLIC cuda_utils logger CUDA::cudart CUDA::cuda_driver)

set_property(TARGET core PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET core PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
22 changes: 21 additions & 1 deletion src/turbomind/core/cuda_data_type.h
@@ -1,6 +1,10 @@
#include <cublas_v2.h>


#include <cuda.h>
#include <cuda_runtime.h>

#include <cublas_v2.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
@@ -54,6 +58,22 @@ constexpr DataType from_cuda_dtype(cudaDataType type) {
}
}

constexpr CUtensorMapDataType to_CUtensorMap_dtype(DataType type) {
switch (type) {
case kFloat32:
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
case kFloat16:
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
case kBfloat16:
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case kFloat8_e4m3:
case kFloat8_e5m2:
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
default:
throw std::runtime_error("Not supported " + std::string{to_string(type)});
}
}

// clang-format on

} // namespace turbomind
4 changes: 2 additions & 2 deletions src/turbomind/core/data_type.h
@@ -218,8 +218,8 @@ constexpr const char* to_string(DataType type) {
case kFloat32: return "f32";
case kFloat64: return "f64";
case kBfloat16: return "bf16";
case kFloat8_e4m3: return "f8_e4m3";
case kFloat8_e5m2: return "f8_e5m2";
case kFloat8_e4m3: return "e4m3";
case kFloat8_e5m2: return "e5m2";
case kUint2: return "u2";
case kUint4: return "u4";
case kUint6: return "u8";