Add FP8 MoE for turbomind #3601

Merged
merged 167 commits into from Jun 13, 2025
Changes from 1 commit
167 commits
404263d
low level abstraction
lzhangzz Mar 27, 2025
81bfa75
refactor
lzhangzz Apr 2, 2025
770b85d
eliminate template
lzhangzz Apr 7, 2025
e3a9619
remove unused
lzhangzz Apr 7, 2025
6b9a433
refactor bindings
lzhangzz Apr 7, 2025
613aeec
simplify lm head
lzhangzz Apr 7, 2025
e3fe34c
refactor weight
lzhangzz Apr 8, 2025
1e057d1
fix tp
lzhangzz Apr 8, 2025
40e9097
cublas
lzhangzz Apr 8, 2025
6fc9cc9
Merge remote-tracking branch 'origin/main' into core
lzhangzz Apr 9, 2025
ff3b5f7
refactor sampling
lzhangzz Apr 10, 2025
06ff641
remove unused
lzhangzz Apr 10, 2025
14a7f45
simplify
lzhangzz Apr 11, 2025
096155c
fix AWQ support
lzhangzz Apr 11, 2025
5fd35ae
fix moe
lzhangzz Apr 11, 2025
0c5ef46
fix nccl lm_head
lzhangzz Apr 11, 2025
c2020b2
fix
lzhangzz Apr 11, 2025
510675c
refactor data types
lzhangzz Apr 15, 2025
00b121e
skip legacy ut
lzhangzz Apr 15, 2025
88d17d4
simplify
lzhangzz Apr 15, 2025
699c24f
rename data types
lzhangzz Apr 15, 2025
3ffd070
refactor
lzhangzz Apr 15, 2025
eed6bfb
refactor runtime states
lzhangzz Apr 16, 2025
d2ec3af
fix msvc build
lzhangzz Apr 16, 2025
2529631
fix msvc build
lzhangzz Apr 16, 2025
1d77856
fix msvc build
lzhangzz Apr 16, 2025
6e728cf
fix msvc build
lzhangzz Apr 16, 2025
1b6a80d
fix msvc build
lzhangzz Apr 16, 2025
0d976d3
fix msvc build
lzhangzz Apr 16, 2025
18e7602
fix msvc build
lzhangzz Apr 16, 2025
7fec496
fix msvc build
lzhangzz Apr 16, 2025
69b1841
fix msvc build
lzhangzz Apr 16, 2025
c8bc36d
format
lzhangzz Apr 16, 2025
8161c0d
remove unused
lzhangzz Apr 16, 2025
7459992
fix msvc build
lzhangzz Apr 16, 2025
d38421f
fix msvc build
lzhangzz Apr 16, 2025
7d6ab03
fix msvc build
lzhangzz Apr 16, 2025
b214a0e
fix msvc build
lzhangzz Apr 16, 2025
3ab38ca
fix msvc build
lzhangzz Apr 16, 2025
f394ef0
fix msvc build
lzhangzz Apr 16, 2025
105f1cc
fix msvc build
lzhangzz Apr 16, 2025
5ccf30c
fix msvc build
lzhangzz Apr 16, 2025
b59620c
fix msvc build
lzhangzz Apr 16, 2025
a243da0
fix msvc build
lzhangzz Apr 16, 2025
42172d3
fix msvc build
lzhangzz Apr 16, 2025
4d9910a
fix msvc build
lzhangzz Apr 16, 2025
bf7c213
fix msvc build
lzhangzz Apr 16, 2025
8651edd
fix msvc build
lzhangzz Apr 16, 2025
4788b80
fix msvc build
lzhangzz Apr 16, 2025
529225d
fix msvc build
lzhangzz Apr 16, 2025
6fd5f72
fix msvc build
lzhangzz Apr 17, 2025
86d4e86
fix msvc build
lzhangzz Apr 17, 2025
8ea7e20
fix ut & msvc build
lzhangzz Apr 17, 2025
dcf9669
fix ut & msvc build
lzhangzz Apr 17, 2025
7f8974b
fix gcc build
lzhangzz Apr 17, 2025
646813b
fix lint & ut
lzhangzz Apr 17, 2025
98f4840
fix lint
lzhangzz Apr 17, 2025
ea07957
fetch Catch2 when building tests
lzhangzz Apr 17, 2025
5d69923
rewind msvc build
lzhangzz Apr 17, 2025
d0079b5
fix sampling
lzhangzz Apr 18, 2025
15b4007
fp8 round trip test
lzhangzz Apr 21, 2025
72fa3ee
pseudo quant test
lzhangzz Apr 22, 2025
c4de357
initial sm90 gemm kernel
lzhangzz Apr 24, 2025
26b270a
optimize smem desc
lzhangzz Apr 24, 2025
68ac168
multiple warp groups
lzhangzz Apr 24, 2025
4e573c5
optimize smem desc
lzhangzz Apr 24, 2025
4b801eb
flush TMA ops
lzhangzz Apr 24, 2025
97407ad
TMA epilogue
lzhangzz Apr 24, 2025
a97a3bb
clean-up
lzhangzz Apr 24, 2025
b26d7c8
launch config & fence operand
lzhangzz Apr 24, 2025
24520ee
tuning
lzhangzz Apr 24, 2025
eb7a50e
minor
lzhangzz Apr 24, 2025
c477918
TMA multicast
lzhangzz Apr 25, 2025
04257a6
pipeline
lzhangzz Apr 25, 2025
da5b22d
warp specialization
lzhangzz Apr 25, 2025
c728db2
persistent kernel
lzhangzz Apr 27, 2025
7e03761
better TMA multicast
lzhangzz Apr 28, 2025
4fb3799
initial fp8 gemm
lzhangzz Apr 29, 2025
27f1d29
optimize
lzhangzz Apr 29, 2025
876fbe7
optimize
lzhangzz Apr 30, 2025
316f2f8
cluster boundary check
lzhangzz Apr 30, 2025
fb9dd9a
slow
lzhangzz Apr 30, 2025
dc6f660
revert
lzhangzz Apr 30, 2025
fd515e1
v2
lzhangzz Apr 30, 2025
f0e4e4e
optimize
lzhangzz May 1, 2025
4bb1033
fix scaling
lzhangzz May 6, 2025
87ab9df
fix scaling
lzhangzz May 6, 2025
c1ec1de
optimize
lzhangzz May 7, 2025
45b8218
TMA store swizzle
lzhangzz May 7, 2025
26b6e21
better TMA store swizzle
lzhangzz May 8, 2025
49e6ac9
clean up
lzhangzz May 8, 2025
c2e3564
prefetch U
lzhangzz May 8, 2025
42571c2
prefetch V
lzhangzz May 8, 2025
108e2aa
fix V for multi wg
lzhangzz May 9, 2025
89fbbb7
fix U stride for TMA
lzhangzz May 9, 2025
464a2bc
qwen3 dense fp8
lzhangzz May 13, 2025
0653c65
fix tma multicast
lzhangzz May 13, 2025
c14fcc4
fix producer register count
lzhangzz May 13, 2025
dfbae20
decouple epilogue
lzhangzz May 13, 2025
731f2c5
warpspecialized pingpong
lzhangzz May 15, 2025
77ea397
fix multicast
lzhangzz May 15, 2025
2d991a1
cluster layout
lzhangzz May 15, 2025
86344c7
update
lzhangzz May 16, 2025
56e181b
optimize
lzhangzz May 16, 2025
0a422d3
larger tiles
lzhangzz May 16, 2025
a004cd7
multicast U
lzhangzz May 19, 2025
07c181d
rename
lzhangzz May 19, 2025
b18f336
v3
lzhangzz May 19, 2025
88018a6
optimize v3
lzhangzz May 19, 2025
5959297
tune
lzhangzz May 20, 2025
a3022a3
`tensormap.replace`
lzhangzz May 21, 2025
7eda414
refactor
lzhangzz May 22, 2025
2dc68c2
v4
lzhangzz May 23, 2025
5ccf1d8
init moe support
lzhangzz May 26, 2025
d98be0f
fix
lzhangzz May 27, 2025
2de6d70
fix v stride
lzhangzz May 27, 2025
5ef84ee
fix empty tile & unaligned U
lzhangzz May 27, 2025
7eb5f69
multicast schedule
lzhangzz May 27, 2025
bac2803
scheduling
lzhangzz May 28, 2025
062872d
fix scheduling
lzhangzz May 28, 2025
493ab50
tune
lzhangzz May 29, 2025
e3f3818
fix non-grouped gemm
lzhangzz May 29, 2025
e6119a3
tune group gemm
lzhangzz May 29, 2025
cfacea2
load weight by pointers
lzhangzz May 29, 2025
90ed985
fp8 moe
lzhangzz May 30, 2025
4a58024
Merge remote-tracking branch 'origin/main' into gemm3
lzhangzz May 30, 2025
036d67a
switch to https git
lzhangzz Jun 2, 2025
8474114
fix cutlass tag
lzhangzz Jun 2, 2025
e4b4c49
90 -> 90a
lzhangzz Jun 2, 2025
75eb1ed
fix missing headers
lzhangzz Jun 2, 2025
3493b2d
fix cuda-11
lzhangzz Jun 2, 2025
ce5ac45
guard sm90
lzhangzz Jun 2, 2025
863be32
update
lzhangzz Jun 2, 2025
11289e4
v5
lzhangzz Jun 2, 2025
d31e65d
update
lzhangzz Jun 2, 2025
2fa2ad5
update
lzhangzz Jun 2, 2025
19b93d2
v5
lzhangzz Jun 3, 2025
e1dd6e5
schedule
lzhangzz Jun 5, 2025
887f0db
update
lzhangzz Jun 5, 2025
fb551fe
optimize
lzhangzz Jun 5, 2025
dae5ec9
optimize
lzhangzz Jun 5, 2025
d2aeba0
fix multicast
lzhangzz Jun 5, 2025
a84da18
refactor
lzhangzz Jun 6, 2025
6ebfc5d
refactor
lzhangzz Jun 6, 2025
091abd4
fix performance regression
lzhangzz Jun 9, 2025
4ae99d6
v5
lzhangzz Jun 9, 2025
f376d47
dispatch cluster shape
lzhangzz Jun 9, 2025
e6782a0
optimize
lzhangzz Jun 9, 2025
8f5f9d7
fix sm count
lzhangzz Jun 9, 2025
a573447
guard CUDA version
lzhangzz Jun 9, 2025
b6ea7cd
guard CUDA version
lzhangzz Jun 9, 2025
56e6076
guard CUDA version
lzhangzz Jun 10, 2025
dbc9013
guard CUDA version
lzhangzz Jun 10, 2025
2dc4494
fix
lzhangzz Jun 10, 2025
4168387
fix CUDA version guard
lzhangzz Jun 10, 2025
28972e0
fix
lzhangzz Jun 10, 2025
9dbaa2f
fix MSVC build
lzhangzz Jun 10, 2025
0a0a7a7
fix
lzhangzz Jun 10, 2025
c9a742b
register all kernels
lzhangzz Jun 10, 2025
307202c
refactor
lzhangzz Jun 10, 2025
48b8035
optimize
lzhangzz Jun 10, 2025
67f0615
build with cuda-12.4
lzhangzz Jun 11, 2025
73b2dec
Merge remote-tracking branch 'origin/main' into gemm3
lzhangzz Jun 12, 2025
89ec814
fix lint
lzhangzz Jun 12, 2025
fdb7cfb
fix lint
lzhangzz Jun 12, 2025
3d3f03e
fix lint
lzhangzz Jun 12, 2025
d9586ea
disable debug log
lzhangzz Jun 12, 2025
refactor weight
lzhangzz committed Apr 8, 2025
commit e3fe34c97b4de16332773fb93c9bcbc58691471a
1 change: 1 addition & 0 deletions src/turbomind/core/CMakeLists.txt
@@ -13,6 +13,7 @@ add_library(core STATIC
layout.cc
tensor.cc
tensor.cu
module.cc
typecvt.cc)
target_link_libraries(core PRIVATE CUDA::cudart CUDA::cuda_driver)
set_property(TARGET core PROPERTY POSITION_INDEPENDENT_CODE ON)
78 changes: 78 additions & 0 deletions src/turbomind/core/module.cc
@@ -0,0 +1,78 @@

#include "src/turbomind/core/module.h"
#include "src/turbomind/core/check.h"
#include <optional>

namespace turbomind::core {

Module::Module(): parent_{} {}

Module::~Module()
{
if (parent_) {
parent_->remove_module(*this);
parent_ = {};
}
}

void Module::register_module(std::string name, Module& module, std::optional<int> index)
{
module.parent_ = this;
if (index) {
name += ".";
name += std::to_string(*index);
}
// std::cout << "register Module " << name << " " << &module << ", parent " << this << "\n";
modules_.emplace_back(std::move(name), &module);
}

void Module::register_parameter(std::string name, Tensor& param)
{
// std::cout << "register Parameter " << name << " " << &param << " " << param.layout() << "\n";
params_.emplace_back(std::move(name), &param);
}

void Module::remove_module(Module& module)
{
for (auto it = modules_.begin(); it != modules_.end(); ++it) {
if (it->second == &module) {
// std::cout << "erase " << it->first << " " << &module << " from " << this << "\n";
modules_.erase(it);
return;
}
}
TM_CHECK(0) << "module " << &module << " not found";
}

void Module::remove_parameter(Tensor& param)
{
for (auto it = params_.begin(); it != params_.end(); ++it) {
if (it->second == &param) {
params_.erase(it);
return;
}
}
TM_CHECK(0) << "param " << &param << " not found";
}

TensorMap Module::get_parameters() const
{
TensorMap m;
get_parameters_impl({}, m);
return m;
}

void Module::get_parameters_impl(std::string prefix, TensorMap& m) const
{
if (!prefix.empty()) {
prefix += ".";
}
for (const auto& [k, v] : params_) {
m.emplace(prefix + k, *v);
}
for (const auto& [k, v] : modules_) {
v->get_parameters_impl(prefix + k, m);
}
}

} // namespace turbomind::core
36 changes: 36 additions & 0 deletions src/turbomind/core/module.h
@@ -0,0 +1,36 @@

#include "src/turbomind/core/tensor.h"

namespace turbomind::core {

class Module {
public:
virtual ~Module();

Module();

Module(const Module&) = delete;
Module& operator=(const Module&) = delete;

Module(Module&&) noexcept = delete;
Module& operator=(Module&&) noexcept = delete;

void register_module(std::string name, Module& module, std::optional<int> index = {});
void register_parameter(std::string name, Tensor& param);

void remove_module(Module& module);
void remove_parameter(Tensor& param);

TensorMap get_parameters() const;

private:
void get_parameters_impl(std::string prefix, TensorMap& m) const;

protected:
Module* parent_;

std::vector<std::pair<std::string, Module*>> modules_;
std::vector<std::pair<std::string, Tensor*>> params_;
};

} // namespace turbomind::core
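
For orientation, here is a minimal usage sketch of the registration API introduced above; it is illustrative and not part of the diff. The struct names, dimensions, and dtype parameter are invented, and the core::Tensor construction follows the (shape, dtype, MEMORY_GPU) pattern used elsewhere in this commit.

#include "src/turbomind/core/module.h"

namespace turbomind {

// Sketch only: a leaf module that owns a single parameter, mirroring how
// LlamaDenseWeight registers its tensors later in this commit.
struct ExampleLinear: core::Module {
    ExampleLinear(int in, int out, DataType dtype)
    {
        weight = core::Tensor{{in, out}, dtype, MEMORY_GPU};
        register_parameter("weight", weight);
    }
    core::Tensor weight;
};

// A parent module registering the leaf with an index, the way the model code
// uses tp_rank and layer indices; the index is appended to the name ("proj.0").
struct ExampleBlock: core::Module {
    ExampleBlock(int dim, DataType dtype): proj{dim, dim, dtype}
    {
        register_module("proj", proj, /*index=*/0);
    }
    ExampleLinear proj;
};

// core::Module::get_parameters() walks the tree and joins names with '.', so an
// ExampleBlock registered by its owner via register_module("layers", block, 3)
// exposes "layers.3.proj.0.weight" in the returned TensorMap.

}  // namespace turbomind
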
9 changes: 5 additions & 4 deletions src/turbomind/kernels/norm/rms_norm.cu
@@ -84,11 +84,12 @@ __global__ void RMSNorm(T* dst,

} // namespace kernel

void invokeRMSNorm(core::Tensor& out, const core::Tensor& x, const void* w, float eps, cudaStream_t st)
void invokeRMSNorm(core::Tensor& out, const core::Tensor& x, const core::Tensor& w, float eps, cudaStream_t st)
{
TM_CHECK(x.ndim() == 2);
TM_CHECK(out.shape() == x.shape());
TM_CHECK(out.dtype() == x.dtype());
TM_CHECK(w.dtype() == x.dtype() && w.shape(-1) == x.shape(-1));

if (x.size() == 0) {
return;
@@ -108,7 +109,7 @@ void invokeRMSNorm(core::Tensor& out, const core::Tensor& x, const void* w, floa
out.stride(0),
(const T*)x.raw_data(),
x.stride(0),
(const T*)w,
(const T*)w.raw_data(),
dim,
num,
eps,
@@ -227,7 +228,7 @@ void invokeQkRMSNorm(void* data,
}
}

void invokeRMSNormQK(core::Tensor& x, const void* w, float eps, cudaStream_t st)
void invokeRMSNormQK(core::Tensor& x, const core::Tensor& w, float eps, cudaStream_t st)
{
TM_CHECK(x.ndim() == 3);

@@ -253,7 +254,7 @@ void invokeRMSNormQK(core::Tensor& x, const void* w, float eps, cudaStream_t st)
const int grid_dim = cdiv(threads, block_dim);

kernel::RMSNormQK<T, float, vec_size, max_dim><<<grid_dim, block_dim, 0, st>>>(
(T*)data, stride, (const T*)w, head_dim, head_num, token_num, eps, 1.f / head_dim);
(T*)data, stride, (const T*)w.raw_data(), head_dim, head_num, token_num, eps, 1.f / head_dim);
};

constexpr constant<128> max_dim{};
14 changes: 2 additions & 12 deletions src/turbomind/kernels/norm/rms_norm.h
@@ -8,19 +8,9 @@

namespace turbomind {

void invokeRMSNorm(core::Tensor& out, const core::Tensor& x, const void* w, float eps, cudaStream_t st);
void invokeRMSNorm(core::Tensor& out, const core::Tensor& x, const core::Tensor& w, float eps, cudaStream_t st);

inline void invokeRMSNorm(core::Tensor& out, const core::Tensor& x, const core::Buffer& w, float eps, cudaStream_t st)
{
return invokeRMSNorm(out, x, w.raw_data(), eps, st);
}

void invokeRMSNormQK(core::Tensor& x, const void* w, float eps, cudaStream_t st);

inline void invokeRMSNormQK(core::Tensor& x, const core::Buffer& w, float eps, cudaStream_t st)
{
return invokeRMSNormQK(x, w.raw_data(), eps, st);
}
void invokeRMSNormQK(core::Tensor& x, const core::Tensor& w, float eps, cudaStream_t st);

template<class T>
void invokeBiasResidualRMSNorm(
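
A hedged call sketch for the new Tensor-based signature: the weight now carries its own dtype and shape, so the checks added in rms_norm.cu (w.dtype() == x.dtype(), w.shape(-1) == x.shape(-1)) can run before launch. The function name, dimensions, dtype, stream, and allocation context below are assumptions of the sketch, not part of the change.

#include "src/turbomind/kernels/norm/rms_norm.h"

namespace turbomind {

// Sketch only: dims, dtype, and stream come from the caller; tensors are
// allocated with the (shape, dtype, MEMORY_GPU) pattern used in this commit.
void rmsnorm_call_example(int token_num, int hidden_dim, DataType dtype, cudaStream_t stream)
{
    core::Tensor x{{token_num, hidden_dim}, dtype, MEMORY_GPU};
    core::Tensor w{{hidden_dim}, dtype, MEMORY_GPU};
    core::Tensor out{{token_num, hidden_dim}, dtype, MEMORY_GPU};

    // Previously the weight was passed as `const void*` (or via the removed
    // core::Buffer overload); now it travels as a typed core::Tensor.
    invokeRMSNorm(out, x, w, 1e-6f, stream);
}

}  // namespace turbomind
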
1 change: 1 addition & 0 deletions src/turbomind/models/llama/CMakeLists.txt
@@ -13,6 +13,7 @@ add_library(Llama STATIC
BlockTrie.cc
SequenceManager.cc
LlamaWeight.cc
LlamaDenseWeight.cc
LlamaDecoderLayerWeight.cc
LlamaFfnLayer.cc
moe_ffn_layer.cc
285 changes: 75 additions & 210 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -74,91 +74,53 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(DataType data_type,
mlp_tp_size_(engine.mlp_tp_size),
mlp_tp_rank_(engine.mlp_tp_rank)
{
self_attn_weights = LlamaAttentionWeight{hidden_units_,
size_per_head_,
head_num_,
kv_head_num_,
model.mla,
attn_bias_,
model.qk_norm,
attn_tp_size_,
data_type_,
weight_type_,
model.group_size};

ffn_weights = LlamaFfnWeight{
hidden_units_,
inter_size_,
mlp_tp_size_,
data_type_,
weight_type_,
model.group_size,
weight_type_ == TYPE_UINT4 && is_fuse_silu_act(),
};

moe_weights = MoeFfnWeight{layer_id,
moe_param,
hidden_units_,
data_type_,
weight_type_,
model.group_size,
mlp_tp_size_,
is_fuse_silu_act()};

if (lora_param.policy == LoraPolicy::kPlora) {
std::vector<std::string> keys = {
"attention.w_qkv", "attention.wo", "feed_forward.w1", "feed_forward.w2", "feed_forward.w3"};
std::vector<LlamaDenseWeight*> weights = {&self_attn_weights.qkv,
&self_attn_weights.output,
&ffn_weights.gating,
&ffn_weights.output,
&ffn_weights.intermediate};
for (int i = 0; i < keys.size(); i++) {
const auto& name = keys[i];
auto& weight = *weights[i];
int rank = lora_param.r;
float scale = lora_param.scale;
std::string full_name = "layers." + std::to_string(layer_id) + "." + name;

for (const auto& [re, pr] : lora_param.rank_pattern) {
if (std::regex_search(full_name, pr.first)) {
rank = pr.second;
TM_LOG_DEBUG("find rank, pattern=%s, name=%s, value=%d", re.c_str(), full_name.c_str(), rank);
break;
}
}
for (const auto& [re, pr] : lora_param.scale_pattern) {
if (std::regex_search(full_name, pr.first)) {
scale = pr.second;
TM_LOG_DEBUG("find scale pattern=%s, name=%s, value=%f", re.c_str(), full_name.c_str(), scale);
break;
}
}
if (rank) {
weight.lora.r = rank;
weight.lora.scale = scale;
weight.lora.policy = lora_param.policy;
}
}
}

fused_up_and_gate_ = ffn_weights.gating.lora.policy != LoraPolicy::kPlora;
}

void LlamaDecoderLayerWeight::malloc()
{
self_attn_norm = core::Buffer{hidden_units_, data_type_, MEMORY_GPU};
ffn_norm = core::Buffer{hidden_units_, data_type_, MEMORY_GPU};

self_attn_weights.malloc();
self_attn_weights.reset(new LlamaAttentionWeight{hidden_units_,
size_per_head_,
head_num_,
kv_head_num_,
model.mla,
attn_bias_,
model.qk_norm,
attn_tp_size_,
attn_tp_rank_,
data_type_,
weight_type_,
model.group_size});
register_module("attention", *self_attn_weights);

if (inter_size_) {
ffn_weights.malloc();
ffn_weights.reset(new LlamaFfnWeight{
hidden_units_,
inter_size_,
mlp_tp_size_,
mlp_tp_rank_,
data_type_,
weight_type_,
model.group_size,
weight_type_ == TYPE_UINT4 && is_fuse_silu_act(),
});
register_module("feed_forward", *ffn_weights);
}

if (!moe_weights.experts.empty()) {
moe_weights.malloc();
if (layer_id < moe_param.expert_num.size() && moe_param.expert_num[layer_id]) {
moe_weights.reset(new MoeFfnWeight{layer_id,
moe_param,
hidden_units_,
data_type_,
weight_type_,
model.group_size,
mlp_tp_size_,
mlp_tp_rank_,
is_fuse_silu_act()});
register_module("moe_ffn", *moe_weights);
}

fused_up_and_gate_ = ffn_weights->gating.lora.policy != LoraPolicy::kPlora;

self_attn_norm = core::Tensor{{hidden_units_}, data_type_, MEMORY_GPU};
ffn_norm = core::Tensor{{hidden_units_}, data_type_, MEMORY_GPU};
register_parameter("attention_norm.weight", self_attn_norm);
register_parameter("ffn_norm.weight", ffn_norm);
}

size_t LlamaDecoderLayerWeight::workspace_size() const noexcept
@@ -169,128 +131,23 @@ size_t LlamaDecoderLayerWeight::workspace_size() const noexcept

size_t size = 0;

size = std::max(size, get_size(self_attn_weights.qkv));
size = std::max(size, get_size(self_attn_weights.output));
size = std::max(size, get_size(ffn_weights.gating));
size = std::max(size, get_size(ffn_weights.fused_gating_intermediate));
size = std::max(size, get_size(self_attn_weights->qkv));
size = std::max(size, get_size(self_attn_weights->output));
size = std::max(size, get_size(ffn_weights->gating));
size = std::max(size, get_size(ffn_weights->fused_gating_intermediate));

for (const auto& e : moe_weights.experts) {
size = std::max(size, get_size(e.gating));
size = std::max(size, get_size(e.fused_gating_intermediate));
if (moe_weights) {
for (const auto& e : moe_weights->experts) {
size = std::max(size, get_size(e->gating));
size = std::max(size, get_size(e->fused_gating_intermediate));
}
}

return size * sizeof(uint16_t);
}

template<typename FirstArg, typename... Args>
std::string concat(FirstArg&& first, Args&&... args)
{
std::stringstream stream;
stream << first;
((stream << "." << args), ...);
return stream.str();
}

void getWeightTensor(LlamaDenseWeight& dense, bool bias, const std::string& prefix, core::TensorMap& output)
{
auto get_name = [=](const std::string& name) { return concat(prefix, name); };

TM_CHECK_EQ(bias, bool(dense.bias));
if (bias) {
output.emplace(get_name("bias"), dense.bias);
}

const size_t bit_size = core::get_byte_size(dense.weight_type, 8);
if (bit_size >= 16) {
output.emplace(get_name("weight"), dense.weight);
}
else {
output.emplace(get_name("qweight"), dense.weight);
output.emplace(get_name("scales"), dense.scales);
output.emplace(get_name("zeros"), dense.zeros);
}
}

void LlamaDecoderLayerWeight::free()
{
self_attn_norm = {};
ffn_norm = {};

self_attn_weights.free();

if (inter_size_) {
ffn_weights.free();
}

if (!moe_weights.experts.empty()) {
moe_weights.free();
}
}

LlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() = default;

void getMLATensor(LlamaAttentionWeight& w, const std::string& p, core::TensorMap& m, int tp_rank)
{
if (w.q_proj.output_dim) {
getWeightTensor(w.q_proj, false, concat(p, "attention.q_proj", tp_rank), m);
}
else {
getWeightTensor(w.q_a_proj, false, concat(p, "attention.q_a_proj"), m);
getWeightTensor(w.q_b_proj, false, concat(p, "attention.q_b_proj", tp_rank), m);
m.emplace(concat(p, "attention.q_a_layernorm"), w.q_a_layernorm);
}
getWeightTensor(w.kv_a_proj, false, concat(p, "attention.kv_a_proj"), m);
getWeightTensor(w.kv_b_proj, false, concat(p, "attention.kv_b_proj", tp_rank), m);
m.emplace(concat(p, "attention.kv_a_layernorm"), w.kv_a_layernorm);
}

core::TensorMap LlamaDecoderLayerWeight::getParams(std::string prefix)
{
core::TensorMap output;

output.emplace(concat(prefix, "attention_norm.weight"), self_attn_norm);
output.emplace(concat(prefix, "ffn_norm.weight"), ffn_norm);

auto get_attn = [=](std::string_view name) { return concat(prefix, name, attn_tp_rank_); };

if (self_attn_weights.qkv.output_dim) {
getWeightTensor(self_attn_weights.qkv, attn_bias_, get_attn("attention.w_qkv"), output);

if (self_attn_weights.qk_norm) {
output.emplace(concat(prefix, "attention.q_norm"), self_attn_weights.q_a_layernorm);
output.emplace(concat(prefix, "attention.k_norm"), self_attn_weights.kv_a_layernorm);
}
}
else {
getMLATensor(self_attn_weights, prefix, output, attn_tp_rank_);
}
getWeightTensor(self_attn_weights.output, attn_bias_, get_attn("attention.wo"), output);

auto get_mlp = [=](std::string_view name) { return concat(prefix, name, mlp_tp_rank_); };

if (inter_size_) {
getWeightTensor(ffn_weights.gating, false, get_mlp("feed_forward.w1"), output);
getWeightTensor(ffn_weights.intermediate, false, get_mlp("feed_forward.w3"), output);
getWeightTensor(ffn_weights.output, false, get_mlp("feed_forward.w2"), output);
}

if (!moe_weights.experts.empty()) {
output.emplace(concat(prefix, "moe_ffn.gate.weight"), moe_weights.gate.weight);
auto& experts = moe_weights.experts;
for (size_t i = 0; i < experts.size(); ++i) {
const std::string name = "moe_ffn.experts." + std::to_string(i);
getWeightTensor(experts[i].gating, false, get_mlp(concat(name, "w1")), output);
getWeightTensor(experts[i].intermediate, false, get_mlp(concat(name, "w3")), output);
getWeightTensor(experts[i].output, false, get_mlp(concat(name, "w2")), output);
}
if (moe_weights.shared_gate.weight) {
output.emplace(concat(prefix, "moe_ffn.shared_gate.weight"), moe_weights.shared_gate.weight);
}
}

return output;
}

static void
convert_u4(LlamaDenseWeight& dense, bool is_fused_moe, void* workspace, size_t size, bool use_simt, cudaStream_t st)
{
@@ -543,14 +400,19 @@ void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cudaDe
{
const bool is_16xx = is_16xx_series(prop.name);

convert(self_attn_weights.qkv, false, data_type_, workspace, size, is_16xx, st);
convert(self_attn_weights.output, false, data_type_, workspace, size, is_16xx, st);
convert(self_attn_weights->qkv, false, data_type_, workspace, size, is_16xx, st);
convert(self_attn_weights->output, false, data_type_, workspace, size, is_16xx, st);

auto process_ffn = [&](LlamaFfnWeight& ffn, bool is_fused_moe) {
if (fused_up_and_gate_) {
auto& fused_up_and_gate = ffn.fused_gating_intermediate;

fused_up_and_gate.malloc(st);
fused_up_and_gate.emplace(ffn.gating.input_dim,
ffn.gating.output_dim * 2,
data_type_,
false,
weight_type_,
ffn.gating.group_size);

if (ffn.is_fused_silu) {
interleave(fused_up_and_gate, ffn.gating, ffn.intermediate, data_type_, workspace, size, st);
@@ -561,8 +423,8 @@ void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cudaDe

convert(ffn.fused_gating_intermediate, is_fused_moe, data_type_, workspace, size, is_16xx, st);

ffn.gating.free();
ffn.intermediate.free();
ffn.gating = {};
ffn.intermediate = {};
}
else {
convert(ffn.gating, is_fused_moe, data_type_, workspace, size, is_16xx, st);
@@ -574,37 +436,39 @@ void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cudaDe

if (inter_size_) {
// std::cerr << "process FFN\n";
process_ffn(ffn_weights, false);
process_ffn(*ffn_weights, false);
}

if (!moe_weights.experts.empty()) {
if (moe_weights) {
// std::cerr << "process MoE\n";
std::vector<std::pair<void*, int>> fused_ptrs;
std::vector<std::pair<void*, int>> output_ptrs;
std::vector<std::pair<void*, int>> fused_param_ptrs;
std::vector<std::pair<void*, int>> output_param_ptrs;

for (auto& e : moe_weights.experts) {
for (auto& e : moe_weights->experts) {

process_ffn(e, moe_weights.method == MoeParam::kFused);
process_ffn(*e, moe_weights->method == MoeParam::kFused);

auto& fused = e.fused_gating_intermediate;
auto& output = e.output;
auto& fused = e->fused_gating_intermediate;
auto& output = e->output;

fused_ptrs.push_back({fused.weight.raw_data(), fused.k_desc.ld});
output_ptrs.push_back({output.weight.raw_data(), output.k_desc.ld});

if (e.fused_gating_intermediate.scales_zeros) {
if (e->fused_gating_intermediate.scales_zeros) {
fused_param_ptrs.emplace_back(fused.scales_zeros.raw_data(), fused.q_desc.ld);
output_param_ptrs.emplace_back(output.scales_zeros.raw_data(), output.q_desc.ld);
}
}

#if 0
// Note: This assumes all experts has the same shape
moe_weights.block = moe_weights.experts.at(0);
auto& b_ = moe_weights->block;
auto& e_ = *moe_weights->experts.at(0);


auto& fused = moe_weights.block.fused_gating_intermediate;
auto& output = moe_weights.block.output;
auto& fused = moe_weights->block.fused_gating_intermediate;
auto& output = moe_weights->block.output;

const auto weight_type = fused.weight_type;

@@ -625,6 +489,7 @@ void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cudaDe

fused.q_desc.ld = output.q_desc.ld = 0;
fused.q_desc.num = output.q_desc.num = moe_weights.experts.size();
#endif
}
}

18 changes: 6 additions & 12 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.h
@@ -26,7 +26,7 @@

namespace turbomind {

struct LlamaDecoderLayerWeight {
struct LlamaDecoderLayerWeight: core::Module {
public:
LlamaDecoderLayerWeight() = delete;

@@ -41,23 +41,17 @@ struct LlamaDecoderLayerWeight {
LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight&) = delete;
LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight&) = delete;

core::TensorMap getParams(std::string prefix);

void prepare(void* workspace, size_t size, const cudaDeviceProp& prop, cudaStream_t st);

size_t workspace_size() const noexcept;

void malloc();

void free();

core::Buffer self_attn_norm;
core::Buffer ffn_norm;
core::Tensor self_attn_norm;
core::Tensor ffn_norm;

LlamaAttentionWeight self_attn_weights{};
std::unique_ptr<LlamaAttentionWeight> self_attn_weights;

LlamaFfnWeight ffn_weights{};
MoeFfnWeight moe_weights{};
std::unique_ptr<LlamaFfnWeight> ffn_weights;
std::unique_ptr<MoeFfnWeight> moe_weights;

private:
int head_num_;
158 changes: 158 additions & 0 deletions src/turbomind/models/llama/LlamaDenseWeight.cc
@@ -0,0 +1,158 @@
#include "src/turbomind/models/llama/LlamaDenseWeight.h"

namespace turbomind {

void LlamaDenseWeight::emplace(
int input_dim, int output_dim, DataType data_type, bool bias, DataType weight_type, int group_size)
{
this->data_type = data_type;
this->weight_type = weight_type;
this->input_dim = input_dim;
this->output_dim = output_dim;
this->group_size = group_size;

const auto wbits = core::get_byte_size(weight_type, 8);

weight = core::Tensor({input_dim, output_dim}, weight_type, MEMORY_GPU);
register_parameter(wbits < 16 ? "qweight" : "weight", weight);

if (bias) {
this->bias = core::Tensor{{output_dim}, data_type, MEMORY_GPU};
register_parameter("bias", this->bias);
}

if (wbits < 16) {
TM_CHECK(input_dim % group_size == 0) << input_dim << " " << group_size;
scales = core::Tensor{{input_dim / group_size, output_dim}, data_type, MEMORY_GPU};
zeros = core::Tensor{{input_dim / group_size, output_dim}, data_type, MEMORY_GPU};
register_parameter("scales", scales);
register_parameter("zeros", zeros);
}
}

LlamaAttentionWeight::LlamaAttentionWeight(int hidden_dim,
int head_dim,
int head_num,
int kv_head_num,
MLAParam mla,
bool bias,
bool qk_norm,
int tp_size,
int tp_rank,
DataType data_type,
DataType weight_type,
int group_size)
{
if (mla.kv_lora_rank == 0) {
qkv.emplace(
hidden_dim, (head_num + 2 * kv_head_num) * head_dim / tp_size, data_type, bias, weight_type, group_size);
register_module("w_qkv", qkv, tp_rank);
if (qk_norm) {
q_a_layernorm = core::Tensor{{head_dim}, data_type, MEMORY_GPU};
kv_a_layernorm = core::Tensor{{head_dim}, data_type, MEMORY_GPU};
register_parameter("q_norm", q_a_layernorm);
register_parameter("k_norm", kv_a_layernorm);
}
}
else {
const int qk_nope_dim = head_dim - mla.qk_rope_dim;
if (mla.q_lora_rank) {
q_a_proj.emplace(hidden_dim, mla.q_lora_rank, data_type, false, weight_type, group_size);
q_b_proj.emplace(mla.q_lora_rank, head_num * head_dim / tp_size, data_type, false, weight_type, group_size);
q_a_layernorm = core::Tensor{{q_b_proj.input_dim}, data_type, MEMORY_GPU};
register_module("q_a_proj", q_a_proj);
register_module("q_b_proj", q_b_proj, tp_rank);
register_parameter("q_a_layernorm", q_a_layernorm);
}
else {
q_proj.emplace(hidden_dim, head_num * head_dim / tp_size, data_type, false, weight_type, group_size);
register_module("q_proj", q_proj, tp_rank);
}
kv_a_proj.emplace(hidden_dim, mla.kv_lora_rank + mla.qk_rope_dim, data_type, false, weight_type, group_size);
kv_b_proj.emplace(mla.kv_lora_rank,
head_num * (qk_nope_dim + mla.v_head_dim) / tp_size,
data_type,
false,
weight_type,
group_size);

kv_a_layernorm = core::Tensor{{kv_b_proj.input_dim}, data_type, MEMORY_GPU};
register_module("kv_a_proj", kv_a_proj);
register_module("kv_b_proj", kv_b_proj, tp_rank);
register_parameter("kv_a_layernorm", kv_a_layernorm);
}
output.emplace((head_num * head_dim) / tp_size, hidden_dim, data_type, bias, weight_type, group_size);
register_module("wo", output, tp_rank);
}

LlamaFfnWeight::LlamaFfnWeight(int hidden_dim,
int inter_size,
int tp_size,
int tp_rank,
DataType data_type,
DataType weight_type,
int group_size,
bool fuse_silu_act)
{
TM_CHECK(inter_size % tp_size == 0) << inter_size << " " << tp_size;

inter_size /= tp_size;

this->inter_size = inter_size;

gating.emplace(hidden_dim, inter_size, data_type, false, weight_type, group_size);

intermediate.emplace(hidden_dim, inter_size, data_type, false, weight_type, group_size);

// fused_gating_intermediate = {hidden_dim, inter_size * 2, data_type, weight_type, group_size};
is_fused_silu = fuse_silu_act;

output.emplace(inter_size, hidden_dim, data_type, false, weight_type, group_size);

register_module("w1", gating, tp_rank);
register_module("w3", intermediate, tp_rank);
register_module("w2", output, tp_rank);
}

MoeFfnWeight::MoeFfnWeight(int layer_id,
const MoeParam& param,
int hidden_dim,
DataType data_type,
DataType weight_type,
int group_size,
int tp_size,
int tp_rank,
bool fuse_silu_act)
{
if ((int)param.expert_num.size() <= layer_id) {
return;
}

const int expert_num = param.expert_num[layer_id];

if (expert_num == 0) {
return;
}

// printf("%d %d %d\n", (int)hidden_dim, (int)param.inter_size, (int)expert_num);

gate.emplace(hidden_dim, expert_num, data_type, false, data_type, 1);
register_module("gate", gate);

method = param.method;
fuse_silu_act = fuse_silu_act && method == MoeParam::kFused;

experts.reserve(expert_num);
for (int i = 0; i < expert_num; ++i) {
experts.emplace_back(new LlamaFfnWeight{
hidden_dim, param.inter_size, tp_size, tp_rank, data_type, weight_type, group_size, fuse_silu_act});
register_module("experts", *experts.back(), i);
}

if (param.shared_gate) {
shared_gate.emplace(hidden_dim, 1, data_type, false, data_type, 1);
register_module("shared_gate", shared_gate);
}
}

} // namespace turbomind
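
To make the effect of these registrations concrete, here is a hedged sketch of enumerating the flattened map on a decoder layer. It assumes core::TensorMap iterates as (name, tensor) pairs, consistent with the loop in the removed LlamaWeight::getParams; the key patterns in the comments follow the register_module/register_parameter calls above and match the scheme the removed getWeightTensor/getParams code built by hand.

#include <iostream>

#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"

namespace turbomind {

// Sketch only: prints the dotted parameter names produced by the Module tree.
// For a dense 16-bit layer with tp_rank 0, keys are expected to look like
//   attention.w_qkv.0.weight, attention.wo.0.weight, feed_forward.w1.0.weight,
//   attention_norm.weight, ffn_norm.weight
// while quantized weights expose "qweight"/"scales"/"zeros" instead of "weight",
// and MoE layers add moe_ffn.gate.weight, moe_ffn.experts.<i>.w1.0.weight, ...
void dump_layer_parameter_names(const LlamaDecoderLayerWeight& layer)
{
    for (const auto& [name, tensor] : layer.get_parameters()) {
        (void)tensor;
        std::cout << name << "\n";
    }
}

}  // namespace turbomind
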
241 changes: 36 additions & 205 deletions src/turbomind/models/llama/LlamaDenseWeight.h
@@ -20,6 +20,7 @@
#pragma once

#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/module.h"
#include "src/turbomind/core/tensor.h"

#include "src/turbomind/kernels/gemm/types.h"
@@ -47,7 +48,23 @@ struct LoraWeight {
void* b;
};

struct LlamaDenseWeight {
struct LlamaDenseWeight: public core::Module {

LlamaDenseWeight(): data_type{}, weight_type{}, lora{}, k_desc{}, q_desc{} {}

void emplace(int input_dim, int output_dim, DataType data_type, bool bias, DataType weight_type, int group_size);

LlamaDenseWeight& operator=(std::nullptr_t)
{
this->~LlamaDenseWeight();
new (this) LlamaDenseWeight{};
return *this;
}

operator bool() const noexcept
{
return static_cast<bool>(weight);
}

int input_dim = 0;
int output_dim = 0;
@@ -57,7 +74,7 @@ struct LlamaDenseWeight {
DataType weight_type;

core::Tensor weight;
core::Buffer bias;
core::Tensor bias;

core::Tensor scales;
core::Tensor zeros;
@@ -68,49 +85,9 @@ struct LlamaDenseWeight {

gemm::MatrixLayout k_desc;
gemm::MatrixLayout q_desc;

LlamaDenseWeight(): data_type{}, weight_type{}, lora{}, k_desc{}, q_desc{} {}

LlamaDenseWeight(int input_dim, int output_dim, DataType data_type, DataType weight_type, int group_size):
LlamaDenseWeight{}
{
this->data_type = data_type;
this->weight_type = weight_type;
this->input_dim = input_dim;
this->output_dim = output_dim;
this->group_size = group_size;
}

explicit operator bool() const noexcept
{
return static_cast<bool>(weight);
}

void malloc(bool with_bias = false)
{
if (with_bias) {
bias = core::Buffer{output_dim, data_type, MEMORY_GPU};
}

weight = core::Tensor({input_dim, output_dim}, weight_type, MEMORY_GPU);

if (auto wbits = core::get_byte_size(weight_type, 8); wbits <= 8) {
TM_CHECK_EQ(input_dim % group_size, 0);
scales = core::Tensor{{input_dim / group_size, output_dim}, data_type, MEMORY_GPU};
zeros = core::Tensor{{input_dim / group_size, output_dim}, data_type, MEMORY_GPU};
}
}

void free()
{
bias = {};
weight = {};
scales = {};
zeros = {};
}
};

struct LlamaAttentionWeight {
struct LlamaAttentionWeight: public core::Module {

LlamaAttentionWeight() = default;

@@ -121,80 +98,11 @@ struct LlamaAttentionWeight {
MLAParam mla,
bool bias,
bool qk_norm,
int tp,
int tp_size,
int tp_rank,
DataType data_type,
DataType weight_type,
int group_size)
{
this->bias = bias;
this->head_dim = head_dim;
this->qk_norm = qk_norm;
this->data_type = data_type;
this->weight_type = weight_type;

if (mla.kv_lora_rank == 0) {
qkv = {hidden_dim, (head_num + 2 * kv_head_num) * head_dim / tp, data_type, weight_type, group_size};
}
else {
const int qk_nope_dim = head_dim - mla.qk_rope_dim;
if (mla.q_lora_rank) {
q_a_proj = {hidden_dim, mla.q_lora_rank, data_type, weight_type, group_size};
q_b_proj = {mla.q_lora_rank, head_num * head_dim / tp, data_type, weight_type, group_size};
}
else {
q_proj = {hidden_dim, head_num * head_dim / tp, data_type, weight_type, group_size};
}
kv_a_proj = {hidden_dim, mla.kv_lora_rank + mla.qk_rope_dim, data_type, weight_type, group_size};
kv_b_proj = {
mla.kv_lora_rank, head_num * (qk_nope_dim + mla.v_head_dim) / tp, data_type, weight_type, group_size};
}
output = {(head_num * head_dim) / tp, hidden_dim, data_type, weight_type, group_size};
}

void malloc()
{
if (qkv.output_dim) {
qkv.malloc(bias);
if (qk_norm) {
q_a_layernorm = core::Buffer{head_dim, data_type, MEMORY_GPU};
kv_a_layernorm = core::Buffer{head_dim, data_type, MEMORY_GPU};
}
}
else { // MLA
if (q_proj.output_dim) {
q_proj.malloc();
}
else {
q_a_proj.malloc();
q_b_proj.malloc();
q_a_layernorm = core::Buffer{q_b_proj.input_dim, data_type, MEMORY_GPU};
}
kv_a_proj.malloc();
kv_b_proj.malloc();
kv_a_layernorm = core::Buffer{kv_b_proj.input_dim, data_type, MEMORY_GPU};
}
output.malloc(bias);
}

void free()
{
qkv.free();
q_proj.free();
q_a_proj.free();
q_b_proj.free();
kv_a_proj.free();
kv_b_proj.free();
output.free();
q_a_layernorm = {};
kv_a_layernorm = {};
}

int head_dim{};
bool bias{};
bool qk_norm{};

DataType data_type{};
DataType weight_type{};
int group_size);

LlamaDenseWeight qkv;
LlamaDenseWeight output;
@@ -205,49 +113,22 @@ struct LlamaAttentionWeight {
LlamaDenseWeight kv_a_proj;
LlamaDenseWeight kv_b_proj;

core::Buffer q_a_layernorm;
core::Buffer kv_a_layernorm;
core::Tensor q_a_layernorm;
core::Tensor kv_a_layernorm;
};

struct LlamaFfnWeight {
struct LlamaFfnWeight: core::Module {

LlamaFfnWeight() = default;

LlamaFfnWeight(int hidden_dim,
int inter_size,
int tp,
int tp_size,
int tp_rank,
DataType data_type,
DataType weight_type,
int group_size,
bool fuse_silu_act)
{
TM_CHECK_EQ(inter_size % tp, 0);

this->inter_size = inter_size;

gating = {hidden_dim, inter_size, data_type, weight_type, group_size};
intermediate = {hidden_dim, inter_size, data_type, weight_type, group_size};

fused_gating_intermediate = {hidden_dim, inter_size * 2, data_type, weight_type, group_size};
is_fused_silu = fuse_silu_act;

output = {inter_size, hidden_dim, data_type, weight_type, group_size};
}

void malloc()
{
gating.malloc();
intermediate.malloc();
output.malloc();
}

void free()
{
gating.free();
intermediate.free();
output.free();
fused_gating_intermediate.free();
}
bool fuse_silu_act);

LlamaDenseWeight gating;
LlamaDenseWeight intermediate;
@@ -258,7 +139,7 @@ struct LlamaFfnWeight {
bool is_fused_silu{};
};

struct MoeFfnWeight {
struct MoeFfnWeight: core::Module {

MoeFfnWeight() = default;

@@ -268,65 +149,15 @@ struct MoeFfnWeight {
DataType data_type,
DataType weight_type,
int group_size,
int tp,
bool fuse_silu_act)
{

if ((int)param.expert_num.size() <= layer_id) {
return;
}

const int expert_num = param.expert_num[layer_id];

if (expert_num == 0) {
return;
}

// printf("%d %d %d\n", (int)hidden_dim, (int)param.inter_size, (int)expert_num);

gate = {hidden_dim, expert_num, data_type, data_type, 1};

experts.resize(expert_num);

method = param.method;
fuse_silu_act = fuse_silu_act && method == MoeParam::kFused;

for (auto& e : experts) {
// inter size is divided by tp in `FfnWeight`
e = LlamaFfnWeight{hidden_dim, param.inter_size, tp, data_type, weight_type, group_size, fuse_silu_act};
}

if (param.shared_gate) {
shared_gate = {hidden_dim, 1, data_type, data_type, 1};
}
}

void malloc()
{
gate.malloc();
if (shared_gate.output_dim) {
shared_gate.malloc();
}
for (auto& e : experts) {
e.malloc();
}
}

void free()
{
gate.free();
shared_gate.free();
for (auto& e : experts) {
e.free();
}
block.free();
}

LlamaDenseWeight gate;
std::vector<LlamaFfnWeight> experts;
int tp_size,
int tp_rank,
bool fuse_silu_act);

LlamaDenseWeight gate;
LlamaDenseWeight shared_gate;

std::vector<std::unique_ptr<LlamaFfnWeight>> experts;

// reference into `experts`
LlamaFfnWeight block;

46 changes: 12 additions & 34 deletions src/turbomind/models/llama/LlamaWeight.cc
@@ -62,23 +62,23 @@ LlamaWeight::LlamaWeight(DataType data_type,

core::ContextGuard guard = context();

decoder_layer_weights.reserve(num_layer_);
for (unsigned l = 0; l < num_layer_; ++l) {
decoder_layer_weights.emplace_back(
new LlamaDecoderLayerWeight(data_type, l, model, engine_param, lora_param, moe_param));
decoder_layer_weights.back()->malloc();
}

TM_CHECK_EQ(vocab_size_padded_ % tp_size_, 0);
TM_CHECK_EQ(hidden_units_ % tp_size_, 0);

pre_decoder_embedding = LlamaDenseWeight{embedding_size_, hidden_units_ / tp_size_, data_type, data_type, 1};
pre_decoder_embedding.malloc();
pre_decoder_embedding.emplace(embedding_size_, hidden_units_ / tp_size_, data_type, false, data_type, 1);
post_decoder_embedding.emplace(hidden_units_, vocab_size_padded_ / tp_size_, data_type, false, data_type, 1);
register_module("tok_embeddings", pre_decoder_embedding, tp_rank_);
register_module("output", post_decoder_embedding, tp_rank_);

post_decoder_embedding = LlamaDenseWeight{hidden_units_, vocab_size_padded_ / tp_size_, data_type, data_type, 1};
post_decoder_embedding.malloc();
decoder_layer_weights.reserve(num_layer_);
for (int i = 0; i < num_layer_; ++i) {
decoder_layer_weights.emplace_back(
new LlamaDecoderLayerWeight(data_type, i, model, engine_param, lora_param, moe_param));
register_module("layers", *decoder_layer_weights.back(), i);
}

output_norm_weight = core::Buffer{hidden_units_, data_type_, MEMORY_GPU};
output_norm_weight = core::Tensor{{hidden_units_}, data_type_, MEMORY_GPU};
register_parameter("norm.weight", output_norm_weight);
}

LlamaWeight::~LlamaWeight()
@@ -90,7 +90,6 @@ LlamaWeight::~LlamaWeight()
output_norm_weight = {};

for (auto& p : decoder_layer_weights) {
p->free();
delete p;
}

@@ -105,27 +104,6 @@ core::ContextGuard LlamaWeight::context() const
return core::ContextGuard{stream_, alloca_};
}

core::TensorMap LlamaWeight::getParams()
{
core::TensorMap output;

output.emplace("tok_embeddings." + std::to_string(tp_rank_) + ".weight", pre_decoder_embedding.weight);
output.emplace("output." + std::to_string(tp_rank_) + ".weight", post_decoder_embedding.weight);

output.emplace("norm.weight", output_norm_weight);

// transformer layers
for (size_t i = 0; i < num_layer_; i++) {
std::string prefix = fmtstr("layers.%d", i);
core::TensorMap layer = decoder_layer_weights[i]->getParams(prefix);
for (auto& kv : layer) {
output.insert(std::move(kv));
}
}

return output;
}

void LlamaWeight::prepare(const cudaDeviceProp& prop)
{
core::ContextGuard guard = context();
6 changes: 2 additions & 4 deletions src/turbomind/models/llama/LlamaWeight.h
@@ -27,7 +27,7 @@

namespace turbomind {

struct LlamaWeight {
struct LlamaWeight: core::Module {
LlamaWeight() = default;

LlamaWeight(DataType data_type,
@@ -41,8 +41,6 @@ struct LlamaWeight {
LlamaWeight(const LlamaWeight&) = delete;
LlamaWeight& operator=(const LlamaWeight&) = delete;

core::TensorMap getParams();

void prepare(const cudaDeviceProp& prop);

core::ContextGuard context() const;
@@ -52,7 +50,7 @@ struct LlamaWeight {
LlamaDenseWeight pre_decoder_embedding;
LlamaDenseWeight post_decoder_embedding;

core::Buffer output_norm_weight;
core::Tensor output_norm_weight;

private:
int hidden_units_;
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/moe_ffn_layer.cc
@@ -170,10 +170,10 @@ void MoeFfnLayer::Forward(ForwardParam& p)
}

for (int i = 0; i < expert_num; ++i) {
FT_CHECK(moe.experts[i].is_fused_silu == false);
FT_CHECK(moe.experts[i]->is_fused_silu == false);
if (int count = h_offsets_[i + 1] - h_offsets_[i]) {
auto io = p.temp.slice({h_offsets_[i], 0}, {count, -1});
expert_ffn_->forward({io, io, &moe.experts.at(i), p.layer_id});
expert_ffn_->forward({io, io, moe.experts.at(i).get(), p.layer_id});
}
}
}
8 changes: 4 additions & 4 deletions src/turbomind/models/llama/unified_attention_layer.cc
@@ -109,7 +109,7 @@ struct ForwardParam {
core::Buffer_<int> cu_block_nums;
core::Buffer_<uintptr_t> kv_block_ptrs;

const void* weights;
const LlamaAttentionWeight* weights;

core::Event event;

@@ -123,7 +123,7 @@ void Initialize(ForwardParam& p, core::TensorMap& args, const core::Tensor& inpu
p.Init(args, input, output);
}

void SetLayer(ForwardParam& p, const void* weights, int layer_id)
void SetLayer(ForwardParam& p, const LlamaAttentionWeight* weights, int layer_id)
{
p.weights = weights;
p.layer_id = layer_id;
@@ -235,7 +235,7 @@ void UnifiedAttentionLayer::forward(ForwardParam& p)

const int layer_id = p.layer_id;

const auto& weights = *(const WeightType*)p.weights;
const auto& weights = *p.weights;

// [L, 2, H, s, D]
const size_t layer_offset = layer_id * 2 * local_kv_head_num_ * param_.cache_block_seq_len * size_per_head_;
@@ -313,7 +313,7 @@ core::Tensor UnifiedAttentionLayer::core_attention(core::Tensor& qkv, const Forw
params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;

if (weights.qkv.bias) {
params.q_bias = weights.qkv.bias.unsafe_data<T>();
params.q_bias = weights.qkv.bias.buffer().unsafe_data<T>();
params.k_bias = params.q_bias + local_head_num_ * size_per_head_;
params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_;
}
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/unified_attention_layer.h
@@ -39,7 +39,7 @@ struct ForwardParam;

void Initialize(ForwardParam& p, core::TensorMap& args, const core::Tensor& input, core::Tensor& output);

void SetLayer(ForwardParam& p, const void* weights, int layer_id);
void SetLayer(ForwardParam& p, const LlamaAttentionWeight* weights, int layer_id);

void Finalize(ForwardParam& p);

29 changes: 13 additions & 16 deletions src/turbomind/models/llama/unified_decoder.cc
@@ -55,8 +55,8 @@ UnifiedDecoder::~UnifiedDecoder() = default;

void UnifiedDecoder::AllreduceResidualRMSnorm(core::Tensor& hidden_states,
core::Tensor& residual,
const core::Buffer& bias,
const core::Buffer& weight,
const core::Tensor& bias,
const core::Tensor& weight,
int token_num,
int group0,
int group1,
@@ -68,7 +68,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(core::Tensor& hidden_states,
else if (group0 || group1) {
d_comm_->AllreduceResidualBiasRMSnormEx(hidden_states.raw_data(),
residual.raw_data(),
bias.unsafe_data(),
bias.buffer().unsafe_data(),
weight.raw_data(),
rmsnorm_eps_,
hidden_units_,
@@ -82,7 +82,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(core::Tensor& hidden_states,
else if (d_comm_) {
d_comm_->AllreduceResidualBiasRMSnorm(hidden_states.raw_data(),
residual.raw_data(),
bias.unsafe_data(),
bias.buffer().unsafe_data(),
weight.raw_data(),
rmsnorm_eps_,
hidden_units_,
@@ -96,7 +96,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(core::Tensor& hidden_states,
invokeResidualBiasRMSNorm(hidden_states.raw_data(),
residual.raw_data(),
weight.raw_data(),
bias.unsafe_data(),
bias.buffer().unsafe_data(),
dtype,
hidden_units_,
token_num,
@@ -170,14 +170,14 @@ void UnifiedDecoder::Forward(core::TensorMap& args, const std::vector<WeightType

/////////////////////////////////////////////
/// self-attention
SetLayer(*attn_fwd_param_, &weights.at(layer)->self_attn_weights, layer);
SetLayer(*attn_fwd_param_, weights.at(layer)->self_attn_weights.get(), layer);
attn_layer_->forward(*attn_fwd_param_);

TM_DEBUG_TENSOR(local_hidden_states, Concat("attn_block", layer), 1);

AllreduceResidualRMSnorm(global_hidden_states,
local_residual,
weights.at(layer)->self_attn_weights.output.bias,
weights.at(layer)->self_attn_weights->output.bias,
weights.at(layer)->ffn_norm,
local_token_num,
attn_tp_group_,
@@ -192,22 +192,22 @@ void UnifiedDecoder::Forward(core::TensorMap& args, const std::vector<WeightType

std::optional<MoeFfnLayer::ForwardParam> moe_fwd_param;

if (!weights.at(layer)->moe_weights.experts.empty()) {
if (weights.at(layer)->moe_weights) {
moe_fwd_param = MoeFfnLayer::ForwardParam{global_hidden_states,
global_hidden_states,
{},
ffn_layer_ ? 1.f : 0.f,
(int)layer,
&weights.at(layer)->moe_weights};
weights.at(layer)->moe_weights.get()};
}

if (moe_fwd_param) {
moe_ffn_layer_->Forward(*moe_fwd_param);
}

if (weights.at(layer)->ffn_weights.output.weight) {
if (weights.at(layer)->ffn_weights) {
ffn_layer_->forward(
{global_hidden_states, global_hidden_states, &weights.at(layer)->ffn_weights, (int)layer});
{global_hidden_states, global_hidden_states, weights.at(layer)->ffn_weights.get(), (int)layer});
}

if (moe_fwd_param) {
@@ -218,11 +218,11 @@ void UnifiedDecoder::Forward(core::TensorMap& args, const std::vector<WeightType

const bool last = layer == layer_num_ - 1;

auto& scale_weight = !last ? weights.at(layer + 1)->self_attn_norm : args.at("output_norm_weight").buffer();
auto& scale_weight = !last ? weights.at(layer + 1)->self_attn_norm : args.at("output_norm_weight");

AllreduceResidualRMSnorm(global_hidden_states,
local_residual,
weights.at(layer)->ffn_weights.output.bias,
weights.at(layer)->ffn_weights->output.bias,
scale_weight,
local_token_num,
0,
@@ -260,9 +260,6 @@ void UnifiedDecoder::Forward(core::TensorMap& args, const std::vector<WeightType
}

Finalize(*attn_fwd_param_);

// core::Context::stream().Sync();
// TM_CHECK(0);
}

} // namespace turbomind
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/unified_decoder.h
@@ -43,8 +43,8 @@ class UnifiedDecoder {

void AllreduceResidualRMSnorm(core::Tensor& hidden_states,
core::Tensor& residual,
const core::Buffer& bias,
const core::Buffer& weight,
const core::Tensor& bias,
const core::Tensor& weight,
int token_num,
int t0,
int t1,
4 changes: 1 addition & 3 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -435,9 +435,7 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank)

core::TensorMap LlamaTritonModel::getParams(int device_id, int rank)
{
check_cuda_error(cudaSetDevice(device_id));

return TM_CHECK_NOTNULL(weights_[rank])->getParams();
return TM_CHECK_NOTNULL(weights_[rank])->get_parameters();
}

void LlamaTritonModel::processWeights(int device_id, int rank)