Description
Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
[2025-05-12 05:36:48] server_args=ServerArgs(model_path='/nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/', tokenizer_path='/nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/', tokenizer_mode='auto', skip_tokenizer_init=False, enable_tokenizer_batch_encode=False, load_format='auto', trust_remote_code=True, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='qwen3-235b', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=30000, mem_fraction_static=0.85, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=4, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=740248602, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=True, decode_log_interval=40, api_key='6fcf3eaf8297ab66c9bc76e54920c8ec', file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='qwen3', dp_size=1, load_balance_method='round_robin', ep_size=4, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=True, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser='qwen25', enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None)
INFO 05-12 05:36:48 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP3] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP3] Init torch distributed begin.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP0] Init torch distributed begin.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP2] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP2] Init torch distributed begin.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP1] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP1] Init torch distributed begin.
[2025-05-12 05:36:56 TP0] sglang is using nccl==2.21.5
[2025-05-12 05:36:56 TP2] sglang is using nccl==2.21.5
[2025-05-12 05:36:56 TP3] sglang is using nccl==2.21.5
[2025-05-12 05:36:56 TP1] sglang is using nccl==2.21.5
[2025-05-12 05:36:57 TP3] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP2] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP0] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP1] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP1] Init torch distributed ends. mem usage=0.40 GB
[2025-05-12 05:36:57 TP2] Init torch distributed ends. mem usage=0.40 GB
[2025-05-12 05:36:57 TP0] Init torch distributed ends. mem usage=0.38 GB
[2025-05-12 05:36:57 TP3] Init torch distributed ends. mem usage=0.32 GB
[2025-05-12 05:36:57 TP1] Load weight begin. avail mem=43.49 GB
[2025-05-12 05:36:57 TP0] Load weight begin. avail mem=43.50 GB
[2025-05-12 05:36:57 TP2] Load weight begin. avail mem=43.49 GB
[2025-05-12 05:36:57 TP3] Load weight begin. avail mem=43.57 GB
[2025-05-12 05:36:58 TP1] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2219, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 272, in __init__
    self.tp_worker = TpWorkerClass(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 64, in __init__
    self.worker = TpModelWorker(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 85, in __init__
    self.model_runner = ModelRunner(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 190, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 205, in initialize
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 458, in load_model
    self.model = get_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 372, in load_model
    model = _initialize_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 153, in _initialize_model
    return model_class(
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 328, in __init__
    self.model = Qwen3MoeModel(
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 307, in __init__
    super().__init__(
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen2_moe.py", line 352, in __init__
    self.layers = make_layers(
  File "/sgl-workspace/sglang/python/sglang/srt/utils.py", line 440, in make_layers
    + [
  File "/sgl-workspace/sglang/python/sglang/srt/utils.py", line 441, in <listcomp>
    maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen2_moe.py", line 354, in <lambda>
    lambda idx, prefix: decoder_layer_type(
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 257, in __init__
    self.mlp = Qwen3MoeSparseMoeBlock(
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 80, in __init__
    self.experts = MoEImpl(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 177, in __init__
    self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8.py", line 465, in __new__
    obj.__init__(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 581, in __init__
    self.block_quant = self.quant_config.weight_block_size is not None
AttributeError: 'AWQMarlinConfig' object has no attribute 'weight_block_size'
[2025-05-12 05:36:58 TP2] Scheduler hit an exception: (same AttributeError traceback as TP1 above)
[2025-05-12 05:36:58 TP0] Scheduler hit an exception: (same AttributeError traceback as TP1 above)
[2025-05-12 05:36:58] Received sigquit from a child process. It usually means the child failed.
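For context on the traceback: with --enable-ep-moe, the MoE block is built as the EP-MoE implementation, and the traceback shows its quantization method being constructed as Fp8EPMoEMethod even though the checkpoint is AWQ. That method reads self.quant_config.weight_block_size, an attribute only FP8-style configs define, so AWQMarlinConfig fails with the AttributeError above. A stripped-down sketch of the failing access, using stand-in classes rather than sglang's real ones:

# Stand-ins for sglang's quant configs, for illustration only.
class Fp8Config:
    def __init__(self, weight_block_size=None):
        # FP8 block-quantization metadata; None means block quant disabled.
        self.weight_block_size = weight_block_size

class AWQMarlinConfig:
    # AWQ configs carry fields like group_size and bits,
    # but define no weight_block_size attribute.
    pass

def init_fp8_ep_moe(quant_config):
    # Equivalent of ep_moe/layer.py line 581 in the traceback:
    return quant_config.weight_block_size is not None

init_fp8_ep_moe(Fp8Config())        # fine: returns False
init_fp8_ep_moe(AWQMarlinConfig())  # AttributeError: 'AWQMarlinConfig' object
                                    # has no attribute 'weight_block_size'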
Reproduction
I use sglang:v0.4.6-post2 as the base image and then run pip install vllm==0.8.5 on top of it to create a new image, sglang:v0.4.6-post2-fixed.
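The image build amounts to a two-line Dockerfile along these lines (a sketch; the bare sglang:v0.4.6-post2 tag is an assumption, so adjust the registry prefix to however the image is tagged in your setup):

# Sketch of the custom image build described above.
FROM sglang:v0.4.6-post2
RUN pip install vllm==0.8.5

Below is the specific docker-compose file used for deployment: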
docker-compose.yml
services:
  sglang-qwen-moe-235b-1:
    image: sglang:v0.4.6-post2-fixed
    container_name: sglang-qwen-moe-235b-1
    volumes:
      - /etc/hosts:/etc/hosts
      - nfsshare:/nfsshare:ro
    restart: always
    ports:
      - 11020:30000
    entrypoint: /bin/bash
    command:
      - -c
      - |
        python3 -m sglang.launch_server \
          --model-path /nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/ \
          --trust-remote-code \
          --served-model-name qwen3-235b \
          --api-key xxx \
          --tensor-parallel-size 4 \
          --mem-fraction-static 0.85 \
          --quantization awq_marlin \
          --enable-metrics \
          --host 0.0.0.0 \
          --port 30000 \
          --reasoning-parser qwen3 \
          --tool-call-parser qwen25 \
          --enable-ep-moe \
          --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}}'
    ulimits:
      memlock: -1
      stack: 67108864
    shm_size: '30gb'
    ipc: host
    healthcheck:
      test: ["CMD-SHELL", "curl -f http://localhost:30000/health || exit 1"]
      interval: 60s
      retries: 1
      start_period: 800s
      timeout: 1000s
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0', '1', '2', '3']
              capabilities: [gpu]

volumes:
  nfsshare:
    external: true
    name: nfsshare
Environment
(Same sglang.launch_server command as in the docker-compose file above.)
Activity
EvanSong77 commented on May 19, 2025
@zhyncs Please take a look at this question.
zty-wangli commented on May 21, 2025
I've encountered the same issue when working with the Deepseek-V3-AWQ model.
thesillystudent commented on Jun 19, 2025
Hey, were you able to resolve the issue?
zty-wangli commented on Jun 20, 2025
Maybe this is useful:
#6654 (comment)
EvanSong77 commented on Jun 21, 2025
Don't use --enable-ep-moe.
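Dropping the flag presumably avoids the crash because the model then uses the standard fused MoE path instead of EP-MoE, so the FP8-specific method never touches weight_block_size. For anyone patching locally instead, a guard along these lines illustrates the shape of a workaround (illustrative only, not the upstream fix; is_block_quant is a hypothetical helper):

def is_block_quant(quant_config) -> bool:
    # Hypothetical helper, not sglang's actual code: treat a missing
    # weight_block_size attribute the same as block quantization disabled,
    # so non-FP8 configs such as AWQMarlinConfig don't raise AttributeError.
    return getattr(quant_config, "weight_block_size", None) is not None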