Commit e22d683

support eplb

1 parent 88f3dde · commit e22d683

1 file changed (+30, -4)

lmdeploy/pytorch/models/qwen3_moe.py: 30 additions & 4 deletions
@@ -6,15 +6,25 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
-from lmdeploy.pytorch.distributed import get_tp_world_rank
+from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
 from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
 from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+from lmdeploy.utils import get_logger
 
 from .utils.cudagraph import CudaGraphMixin
 
+logger = get_logger('lmdeploy')
+
+try:
+    from dlblas.layers.moe import eplb
+    use_dlblas = True
+except Exception:
+    use_dlblas = False
+    logger.warning('For higher performance, please install dlBLAS https://github.com/DeepLink-org/dlBLAS')
+
 
 class Qwen3MoeAttention(nn.Module):
     """Rewrite module of Qwen3MoeAttention."""
@@ -199,7 +209,13 @@ def __init__(self,
 
         world_size, _ = get_tp_world_rank()
         _all_reduce = world_size > 1
-
+        if get_dist_manager().current_context().dist_config.enable_eplb:
+            dist_ctx = get_dist_manager().current_context()
+            self.eplb_dispatch_info = eplb.EPLBDispatchInfo.init_new(
+                ep_rank=dist_ctx.ep_rank,
+                layer_idx=layer_idx,
+            )
+            self.num_experts = eplb.get_global_eplb_metadata().num_physical_experts()
         self.experts = build_fused_moe(
             self.hidden_dim,
             self.ffn_dim,
@@ -210,16 +226,17 @@ def __init__(self,
             device=device,
             quant_config=quantization_config,
             all_reduce=_all_reduce,
+            layer_idx=layer_idx,
         )
 
     def forward(self, hidden_states: torch.Tensor):
         """forward."""
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
-
         topk_weights, topk_ids = self.softmax_topk(router_logits)
-
+        if get_dist_manager().current_context().dist_config.enable_eplb:
+            topk_ids = eplb.topk_ids_logical_to_physical(topk_ids, self.eplb_dispatch_info)
         out_states = self.experts(
             hidden_states,
             topk_weights,
@@ -307,6 +324,15 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
                                          dtype=dtype,
                                          device=device)
 
+        if get_dist_manager().current_context().dist_config.enable_eplb:
+            if not use_dlblas:
+                raise ImportError('To enable eplb, please install dlBLAS https://github.com/DeepLink-org/dlBLAS')
+            ep_size, _ = get_ep_world_rank()
+            eplb.init_global_eplb_metadata(
+                ep_size=ep_size,
+                num_routed_experts=config.num_experts,
+                num_hidden_layers=config.num_hidden_layers,
+            )
         # build all decode layers
         self.layers = nn.ModuleList([
             Qwen3MoeDecoderLayer(config, layer_idx, dtype=dtype, device=device)
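
For context, a toy illustration of what the logical-to-physical remapping in the forward hunk accomplishes. This is not dlBLAS's implementation; it only shows the idea that a heavily routed logical expert may own several physical replicas, and the router output is redirected to one of them so load spreads across EP ranks:

    import torch

    def toy_logical_to_physical(topk_ids: torch.Tensor,
                                log2phys: torch.Tensor,
                                replica_count: torch.Tensor) -> torch.Tensor:
        """Toy remap: pick one physical replica for each routed logical expert.

        topk_ids:      [num_tokens, top_k] logical expert ids from the router.
        log2phys:      [num_logical, max_replicas] physical ids per logical expert.
        replica_count: [num_logical] number of valid replicas per logical expert.
        """
        # Choose a replica index per routed expert (random here; real balancers are smarter).
        choice = torch.randint_like(topk_ids, 1 << 30) % replica_count[topk_ids]
        return log2phys[topk_ids, choice]

    # e.g. logical expert 0 has two replicas (physical 0 and 5), expert 1 has one (physical 1)
    log2phys = torch.tensor([[0, 5], [1, 1]])
    replica_count = torch.tensor([2, 1])
    print(toy_logical_to_physical(torch.tensor([[0, 1]]), log2phys, replica_count))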
