[Bug] AttributeError: 'AWQMarlinConfig' object has no attribute 'weight_block_size' when deploying Qwen3-235B-A22B-AWQ #6234

Closed
@EvanSong77

Description

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.
  3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  5. Please use English; otherwise, the issue will be closed.

Describe the bug

[2025-05-12 05:36:48] server_args=ServerArgs(model_path='/nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/', tokenizer_path='/nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/', tokenizer_mode='auto', skip_tokenizer_init=False, enable_tokenizer_batch_encode=False, load_format='auto', trust_remote_code=True, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='qwen3-235b', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=30000, mem_fraction_static=0.85, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=4, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=740248602, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=True, decode_log_interval=40, api_key='6fcf3eaf8297ab66c9bc76e54920c8ec', file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='qwen3', dp_size=1, load_balance_method='round_robin', ep_size=4, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=True, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser='qwen25', enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None)
INFO 05-12 05:36:48 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:52 [__init__.py:239] Automatically detected platform cuda.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:55 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP3] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP3] Init torch distributed begin.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP0] Init torch distributed begin.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP2] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP2] Init torch distributed begin.
INFO 05-12 05:36:56 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-05-12 05:36:56 TP1] Attention backend not set. Use flashinfer backend by default.
[2025-05-12 05:36:56 TP1] Init torch distributed begin.
[2025-05-12 05:36:56 TP0] sglang is using nccl==2.21.5
[2025-05-12 05:36:56 TP2] sglang is using nccl==2.21.5
[2025-05-12 05:36:56 TP3] sglang is using nccl==2.21.5
[2025-05-12 05:36:56 TP1] sglang is using nccl==2.21.5
[2025-05-12 05:36:57 TP3] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP2] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP0] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP1] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-05-12 05:36:57 TP1] Init torch distributed ends. mem usage=0.40 GB
[2025-05-12 05:36:57 TP2] Init torch distributed ends. mem usage=0.40 GB
[2025-05-12 05:36:57 TP0] Init torch distributed ends. mem usage=0.38 GB
[2025-05-12 05:36:57 TP3] Init torch distributed ends. mem usage=0.32 GB
[2025-05-12 05:36:57 TP1] Load weight begin. avail mem=43.49 GB
[2025-05-12 05:36:57 TP0] Load weight begin. avail mem=43.50 GB
[2025-05-12 05:36:57 TP2] Load weight begin. avail mem=43.49 GB
[2025-05-12 05:36:57 TP3] Load weight begin. avail mem=43.57 GB
[2025-05-12 05:36:58 TP1] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2219, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 272, in __init__
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 64, in __init__
self.worker = TpModelWorker(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 85, in __init__
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 190, in __init__
self.initialize(min_per_gpu_memory)
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 205, in initialize
self.load_model()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 458, in load_model
self.model = get_model(
File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
return loader.load_model(
File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 372, in load_model
model = _initialize_model(
File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 153, in _initialize_model
return model_class(
File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 328, in __init__
self.model = Qwen3MoeModel(
File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 307, in __init__
super().__init__(
File "/sgl-workspace/sglang/python/sglang/srt/models/qwen2_moe.py", line 352, in __init__
self.layers = make_layers(
File "/sgl-workspace/sglang/python/sglang/srt/utils.py", line 440, in make_layers
+ [
File "/sgl-workspace/sglang/python/sglang/srt/utils.py", line 441, in <listcomp>
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
File "/sgl-workspace/sglang/python/sglang/srt/models/qwen2_moe.py", line 354, in <lambda>
lambda idx, prefix: decoder_layer_type(
File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 257, in __init__
self.mlp = Qwen3MoeSparseMoeBlock(
File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 80, in __init__
self.experts = MoEImpl(
File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 177, in __init__
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8.py", line 465, in __new__
obj.__init__(*args, **kwargs)
File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 581, in __init__
self.block_quant = self.quant_config.weight_block_size is not None
AttributeError: 'AWQMarlinConfig' object has no attribute 'weight_block_size'
[2025-05-12 05:36:58 TP2] Scheduler hit an exception: (identical traceback to TP1, ending in the same AttributeError)
[2025-05-12 05:36:58 TP0] Scheduler hit an exception: (identical traceback to TP1, ending in the same AttributeError)
[2025-05-12 05:36:58] Received sigquit from a child process. It usually means the child failed.
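
For context on the failure mode: with --enable-ep-moe, the EP-MoE layer constructs Fp8EPMoEMethod, whose initializer reads quant_config.weight_block_size. That attribute exists on FP8 quantization configs but not on AWQMarlinConfig, hence the AttributeError. The following is a minimal, self-contained Python sketch of the problem; the two config classes are simplified stand-ins, not the real sglang definitions, and the getattr guard is only one possible illustration of a defensive fix, not the project's actual patch.

from typing import List, Optional

class Fp8Config:
    # Simplified stand-in: FP8 configs carry an optional block size, e.g. [128, 128].
    def __init__(self, weight_block_size: Optional[List[int]] = None):
        self.weight_block_size = weight_block_size

class AWQMarlinConfig:
    # Simplified stand-in: AWQ Marlin configs describe group-wise int4
    # quantization and define no weight_block_size attribute.
    def __init__(self, group_size: int = 128):
        self.group_size = group_size

def block_quant_flag(quant_config) -> bool:
    # Mirrors the failing line in ep_moe/layer.py: an unconditional
    # attribute access, which raises AttributeError for AWQMarlinConfig.
    return quant_config.weight_block_size is not None

def block_quant_flag_safe(quant_config) -> bool:
    # Defensive variant: treat a missing attribute as "no block quant".
    return getattr(quant_config, "weight_block_size", None) is not None

print(block_quant_flag(Fp8Config([128, 128])))   # True
print(block_quant_flag_safe(AWQMarlinConfig()))  # False
print(block_quant_flag(AWQMarlinConfig()))       # raises AttributeError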

Reproduction

I used sglang:v0.4.6-post2 as the base image and ran pip install vllm==0.8.5 inside it to build a new image, sglang:v0.4.6-post2-fixed. Below is the docker-compose file used for deployment:

docker-compose.yml

services:
  sglang-qwen-moe-235b-1:
    image: sglang:v0.4.6-post2-fixed
    container_name: sglang-qwen-moe-235b-1
    volumes:
      - /etc/hosts:/etc/hosts
      - nfsshare:/nfsshare:ro
    restart: always
    ports:
      - 11020:30000
    entrypoint: /bin/bash
    command:
      - -c
      - |
        python3 -m sglang.launch_server \
        --model-path /nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/ \
        --trust-remote-code \
        --served-model-name qwen3-235b \
        --api-key xxx \
        --tensor-parallel-size 4 \
        --mem-fraction-static 0.85 \
        --quantization awq_marlin \
        --enable-metrics \
        --host 0.0.0.0 \
        --port 30000 \
        --reasoning-parser qwen3 \
        --tool-call-parser qwen25 \
        --enable-ep-moe \
        --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}}'

    ulimits:
      memlock: -1
      stack: 67108864
    shm_size: '30gb'
    ipc: host
    healthcheck:
      test: ["CMD-SHELL", "curl -f http://localhost:30000/health || exit 1"]
      interval: 60s
      retries: 1
      start_period: 800s
      timeout: 1000s
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0', '1', '2', '3']
              capabilities: [gpu]


volumes:
  nfsshare:
    external: true
    name: nfsshare

Environment

python3 -m sglang.launch_server \
        --model-path /nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/ \
        --trust-remote-code \
        --served-model-name qwen3-235b \
        --api-key xxx \
        --tensor-parallel-size 4 \
        --mem-fraction-static 0.85 \
        --quantization awq_marlin \
        --enable-metrics \
        --host 0.0.0.0 \
        --port 30000 \
        --reasoning-parser qwen3 \
        --tool-call-parser qwen25 \
        --enable-ep-moe \
        --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}}'

Activity

EvanSong77 (Author) commented on May 19, 2025

@zhyncs Please take a look at this issue.

zty-wangli commented on May 21, 2025

I've encountered the same issue when working with the DeepSeek-V3-AWQ model.

thesillystudent commented on Jun 19, 2025

Hey, were you able to resolve the issue?

zty-wangli commented on Jun 20, 2025

> Hey, were you able to resolve the issue?

Maybe this is useful: #6654 (comment)

EvanSong77 (Author) commented on Jun 21, 2025

Don't use --enable-ep-moe. With that flag, the EP-MoE layer constructs an FP8 quantization method even though the model is AWQ-quantized (see the traceback above), so it is incompatible with AWQ Marlin; launching without the flag avoids the crash. A corrected command is shown below.
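
For reference, this is the launch command from the reproduction with only --enable-ep-moe removed (the model path and API key are placeholders from the original report; substitute your own):

python3 -m sglang.launch_server \
        --model-path /nfsshare/model-checkpoint/Qwen3-235B-A22B-AWQ/ \
        --trust-remote-code \
        --served-model-name qwen3-235b \
        --api-key xxx \
        --tensor-parallel-size 4 \
        --mem-fraction-static 0.85 \
        --quantization awq_marlin \
        --enable-metrics \
        --host 0.0.0.0 \
        --port 30000 \
        --reasoning-parser qwen3 \
        --tool-call-parser qwen25 \
        --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}}'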
