 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

-from lmdeploy.pytorch.distributed import get_tp_world_rank
+from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
 from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
 from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+from lmdeploy.utils import get_logger

 from .utils.cudagraph import CudaGraphMixin

+logger = get_logger('lmdeploy')
+
+try:
+    from dlblas.layers.moe import eplb
+    use_dlblas = True
+except Exception:
+    use_dlblas = False
+    logger.warning('For higher performance, please install dlBLAS https://github.com/DeepLink-org/dlBLAS')
+

 class Qwen3MoeAttention(nn.Module):
     """Rewrite module of Qwen3MoeAttention."""
@@ -199,7 +209,13 @@ def __init__(self,

         world_size, _ = get_tp_world_rank()
         _all_reduce = world_size > 1
-
+        if get_dist_manager().current_context().dist_config.enable_eplb:
+            dist_ctx = get_dist_manager().current_context()
+            self.eplb_dispatch_info = eplb.EPLBDispatchInfo.init_new(
+                ep_rank=dist_ctx.ep_rank,
+                layer_idx=layer_idx,
+            )
+            self.num_experts = eplb.get_global_eplb_metadata().num_physical_experts()
         self.experts = build_fused_moe(
             self.hidden_dim,
             self.ffn_dim,
@@ -210,16 +226,17 @@ def __init__(self,
             device=device,
             quant_config=quantization_config,
             all_reduce=_all_reduce,
+            layer_idx=layer_idx,
         )

     def forward(self, hidden_states: torch.Tensor):
         """forward."""
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
-
         topk_weights, topk_ids = self.softmax_topk(router_logits)
-
+        if get_dist_manager().current_context().dist_config.enable_eplb:
+            topk_ids = eplb.topk_ids_logical_to_physical(topk_ids, self.eplb_dispatch_info)
         out_states = self.experts(
             hidden_states,
             topk_weights,
@@ -307,6 +324,15 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
                                          dtype=dtype,
                                          device=device)

+        if get_dist_manager().current_context().dist_config.enable_eplb:
+            if not use_dlblas:
+                raise ImportError('To enable eplb, please install dlBLAS https://github.com/DeepLink-org/dlBLAS')
+            ep_size, _ = get_ep_world_rank()
+            eplb.init_global_eplb_metadata(
+                ep_size=ep_size,
+                num_routed_experts=config.num_experts,
+                num_hidden_layers=config.num_hidden_layers,
+            )
         # build all decode layers
         self.layers = nn.ModuleList([
             Qwen3MoeDecoderLayer(config, layer_idx, dtype=dtype, device=device)
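
Note on the new `eplb.topk_ids_logical_to_physical` call: with EPLB enabled, each logical routed expert may be replicated into several physical expert slots spread across the EP ranks, so the router's logical top-k ids have to be remapped to physical slot ids before tokens are dispatched to `self.experts`. The snippet below is only a minimal sketch of that remapping under an assumed `log2phy` table and `replica_count` vector; it is not the dlBLAS implementation, and all names in it are illustrative.

```python
import torch


def topk_ids_logical_to_physical_sketch(topk_ids: torch.Tensor, log2phy: torch.Tensor,
                                        replica_count: torch.Tensor) -> torch.Tensor:
    """Illustrative logical->physical expert-id remapping (not the dlBLAS API).

    topk_ids:      [num_tokens, top_k] logical expert ids from the router
    log2phy:       [num_logical_experts, max_replicas] physical slot id of each replica
    replica_count: [num_logical_experts] number of valid replicas per logical expert
    """
    # Pick one replica per (token, expert) so tokens routed to a hot expert are
    # spread over its copies instead of all hitting the same physical slot.
    replica_idx = torch.randint_like(topk_ids, 2**31) % replica_count[topk_ids]
    return log2phy[topk_ids, replica_idx]
```

In the diff above, the equivalent bookkeeping lives in `eplb.EPLBDispatchInfo` and `eplb.get_global_eplb_metadata()`, which is also why `self.num_experts` is overwritten with the number of physical experts before `build_fused_moe` is called.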