
Meet error in serving with huggingface inference tutorial #16

Closed
JF-D opened this issue Apr 28, 2024 · 19 comments

JF-D commented Apr 28, 2024

Hi Arctic team, great work! I followed the Hugging Face inference tutorial to run inference, but I hit the following error:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [24:34<00:00,  7.56s/it]
WARNING:root:Some parameters are on the meta device device because they were offloaded to the disk.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:31999 for open-end generation.
Traceback (most recent call last):
  File "/mnt/afs/jfduan/LLMInfer/snowflake-arctic/inference/hf_infer.py", line 28, in <module>
    outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1708, in forward
    outputs = self.model(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1397, in forward
    layer_outputs = decoder_layer(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1087, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 808, in forward
    query_states = self.q_proj(hidden_states)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 347, in pre_forward
    set_module_tensor_to_device(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 358, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([7168, 7168]) in "weight" (which has shape torch.Size([100352, 516])), this look incorrect.

Can you help me resolve this? Thanks a lot!

jeffra (Collaborator) commented Apr 28, 2024

Hi @JF-D! Thanks for trying this out. Can you tell me a bit more about your setup? Specifically:

  1. total number and type of GPUs
  2. transformers and deepspeed versions
  3. Did you make any changes to the example code? If so, can you paste it here?

JF-D (Author) commented Apr 29, 2024

  1. I am using 8xA100 GPUs.
  2. transformers==4.40.0.dev0, deepspeed==0.14.2
    I followed these instructions to install the dependencies:
# we recommend setting up a virtual environment for this
virtualenv arctic-venv
source arctic-venv/bin/activate

# faster ckpt download speed
pip install huggingface_hub[hf_transfer]

# clone vllm repo and checkout arctic branch
git clone -b arctic https://github.com/Snowflake-Labs/vllm.git
cd vllm
pip install -e .

# clone Hugging Face and checkout arctic branch
git clone -b arctic https://github.com/Snowflake-Labs/transformers.git

# install deepspeed
pip install "deepspeed>=0.14.2"
  3. I didn't change the example code. The full code is listed below:
import os
# enable hf_transfer for faster ckpt download
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from deepspeed.linear.config import QuantizationConfig

tokenizer = AutoTokenizer.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    trust_remote_code=True
)

quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    max_memory={i: "150GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)

messages = [{"role": "user", "content": "What is 1 + 1 "}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))

BTW, loading the checkpoints takes ~30 min on my server, which is painfully slow.

jeffra (Collaborator) commented Apr 29, 2024

Excellent, one quick follow-up question before diving into the other details. Are these 40GB or 80GB A100s?

Regarding the slow load times, we are working on uploading pre-quantized checkpoints to HF. Hopefully that will help reduce the load times a bit.

JF-D (Author) commented Apr 29, 2024

They are 80GB A100s. I think with the quantization config, I should be able to run a simple example.

jeffra (Collaborator) commented Apr 29, 2024

Gotcha, yeah I think 8xA100-80GB should work here. We have not tested this exact setup since I don't have immediate access to this hardware. I have seen that error message before when some tensors get moved to CPU by device_map="auto". This shouldn't happen if there's enough GPU memory for everything, though, which we have confirmed is the case with 8xH100-80GB.
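
For reference, a quick way to check whether that offloading actually happened (a sketch, not part of the tutorial code) is to inspect the device map accelerate produced after loading:

# Sketch: after from_pretrained(..., device_map="auto"), accelerate records
# where each module landed; anything mapped to "cpu" or "disk" was offloaded,
# which is what triggers the shape-mismatch error above.
offloaded = {name: dev for name, dev in model.hf_device_map.items()
             if dev in ("cpu", "disk")}
print(f"{len(offloaded)} offloaded modules")
for name, dev in offloaded.items():
    print(f"  {name} -> {dev}")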

jeffra (Collaborator) commented Apr 29, 2024

Also, if you haven't already, can you try changing q_bits=6 in the quant config?

JF-D (Author) commented Apr 29, 2024

OK! Let me give it a try and get back to you.

JF-D (Author) commented Apr 29, 2024

Unfortunately, setting q_bits=6 hits the same error:

[2024-04-29 10:42:13,954] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
Using /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121/fp_quantizer/build.ninja...
Building extension module fp_quantizer...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fp_quantizer...
Time to load fp_quantizer op: 0.3470158576965332 seconds
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [03:56<00:00,  1.22s/it]
WARNING:root:Some parameters are on the meta device device because they were offloaded to the disk.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:31999 for open-end generation.
Traceback (most recent call last):
  File "/mnt/afs/jfduan/LLMInfer/snowflake-arctic/inference/hf_infer.py", line 28, in <module>
    outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1708, in forward
    outputs = self.model(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1397, in forward
    layer_outputs = decoder_layer(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1087, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 808, in forward
    query_states = self.q_proj(hidden_states)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 347, in pre_forward
    set_module_tensor_to_device(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 358, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([7168, 7168]) in "weight" (which has shape torch.Size([100352, 516])), this look incorrect.

sfc-gh-reyazda commented

Hi @JF-D,

Can you please try this PR?
Thanks.
Reza

JF-D (Author) commented Apr 30, 2024

Thanks! @sfc-gh-reyazda

I tried the PR you mentioned and hit the following error:

[2024-04-30 13:41:09,352] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.0), only 1.0.0 is known to be compatible
Using /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121/fp_quantizer/build.ninja...
/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module fp_quantizer...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fp_quantizer...
Time to load fp_quantizer op: 0.3507523536682129 seconds
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [14:56<00:00,  4.60s/it]
WARNING:root:Some parameters are on the meta device device because they were offloaded to the disk.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:31999 for open-end generation.
Traceback (most recent call last):
  File "/mnt/afs/jfduan/LLMInfer/snowflake-arctic/inference/hf_infer.py", line 29, in <module>
    outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1708, in forward
    outputs = self.model(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1397, in forward
    layer_outputs = decoder_layer(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1087, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 808, in forward
    query_states = self.q_proj(hidden_states)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/deepspeed/linear/quantization.py", line 137, in forward
    return F.linear(input, self.weight.dequantized(), self.bias)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/deepspeed/linear/quantization.py", line 73, in dequantized
    return self.quantizer.dequantize(self.data,
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/deepspeed/ops/fp_quantizer/quantize.py", line 89, in dequantize
    assert (self.orig_dtype is not None), \
AssertionError: [De-quantization Error]: you need to call quantize before dequantizing!

sfc-gh-reyazda commented

This is very strange! It means that the quantizer used to dequantize the weight does not have `self.orig_dtype` set properly, which in turn means the quantizer for that weight was never called (otherwise it would have been set here). So this suggests to me that we are probably using different versions of transformers, as I am not able to reproduce the issue you are seeing. I tried this on an older commit of Snowflake-Labs/transformers: 6b1fe691bf8c34318f1beb5124db1162d93f047e.
Which branch/commit are you using?

JF-D (Author) commented Apr 30, 2024

I checked my transformers version; the latest commit is the same one you tried (6b1fe691bf8c34318f1beb5124db1162d93f047e).

JF-D (Author) commented Apr 30, 2024

I found the error. When trying to quantize the weights, DeepSpeed finds that the tensor is on the meta device instead of a GPU, so the tensor is never quantized (here).

But I think I should be able to run the Arctic model with FP8 quantization on 8x80GB A100s. It's quite strange; maybe something is wrong with Hugging Face accelerate?
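
One way to see this directly (again just a sketch, not from the tutorial) is to list the parameters that are still on the meta device after loading; those are exactly the weights the quantizer never touched:

# Sketch: parameters left on the meta device were offloaded by accelerate and
# therefore skipped by the DeepSpeed quantizer, which later trips the
# "call quantize before dequantizing" assertion during generate().
meta_params = [name for name, p in model.named_parameters()
               if p.device.type == "meta"]
print(f"{len(meta_params)} parameters still on the meta device")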

JF-D (Author) commented Apr 30, 2024

I think I found the reason: transformers is not aware of the DeepSpeed quantization config, so accelerate computes a wrong auto device placement (here).
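
Rough numbers that illustrate the mismatch (back-of-the-envelope only, assuming Arctic's roughly 480B total parameters; not from the thread): accelerate plans the placement from the bf16 sizes it sees, not from the quantized sizes that actually end up resident.

# Illustrative arithmetic only; the parameter count is an approximation.
total_params = 480e9                  # Arctic has ~480B total parameters
gpu_budget   = 8 * 80 * 2**30         # 8 x 80 GiB A100s

bf16_bytes = total_params * 2         # what accelerate budgets for (2 bytes/param)
fp8_bytes  = total_params * 1         # roughly what is resident with q_bits=8

print(f"planned (bf16): {bf16_bytes / 2**30:.0f} GiB")   # ~894 GiB -> forces offload
print(f"resident (fp8): {fp8_bytes / 2**30:.0f} GiB")    # ~447 GiB -> would fit
print(f"GPU budget:     {gpu_budget / 2**30:.0f} GiB")   # 640 GiB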

sfc-gh-reyazda commented

How about explicitly specifying it:

quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "/checkpoint/2b-v30",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    max_memory={i: "150GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)

JF-D (Author) commented May 1, 2024

I have set the config as follows:

quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    max_memory={i: "80GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)

transformers does not pick up the quantization config passed via ds_quantization_config. Note that I am using 80GB A100s, so I set max_memory to 80GiB, which leads to a wrong device mapping. I can run the example by setting max_memory to 160GiB to mimic the effect of quantization.
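
For reference, the working configuration described above would then look like the following (a sketch assembled from the comments; only max_memory differs from the tutorial code):

quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    # ~2x the physical 80 GiB per GPU, so accelerate's bf16-sized placement
    # plan keeps every layer on GPU; the quantized weights then actually fit.
    max_memory={i: "160GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)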

jeffra (Collaborator) commented May 1, 2024

Ohh yes, you have to set max_memory to ~2x the actual memory available so that accelerate will do the right thing. To confirm, you are running successfully now after making this change, right?

We are actively working on adding DeepSpeed quantization support to HFQuantizer instead of the current approach. That should smooth out this path once it's live.

JF-D (Author) commented May 2, 2024

Yes! I can run successfully after setting max_memory to ~2x the actual memory available. Thanks for the help!

jeffra (Collaborator) commented May 2, 2024

Excellent, glad to hear :) I'll close this for now, then; please re-open if there are any remaining issues.

jeffra closed this as completed May 2, 2024