Make the quantized data shape compatible with original tensor shape #5483

Open · wants to merge 47 commits into base: master

Commits (47)
085fcd8
Make the quantized data shape compatible with original tensor shape
sfc-gh-reyazda Apr 30, 2024
a83b384
change the scale and quantized data format
sfc-gh-reyazda May 4, 2024
048648d
minor fixes
sfc-gh-reyazda May 8, 2024
bf12893
fix
sfc-gh-reyazda May 15, 2024
b18f71f
minor fix
sfc-gh-reyazda May 15, 2024
4d6e04b
Merge branch 'master' into fix-quantized-shape
sfc-gh-reyazda May 15, 2024
f924455
more fixed
sfc-gh-reyazda Jun 9, 2024
e03c0f4
Merge branch 'fix-quantized-shape' of https://github.com/Snowflake-La…
sfc-gh-reyazda Jun 9, 2024
d9cfba6
Improve _configure_optimizer() final optimizer log (#5528)
nelyahu May 15, 2024
2bbc680
Enhance testing: Skip fused_optimizer tests if not supported. (#5159)
vshekhawat-hlab May 16, 2024
b3ab626
Skip the UT cases that use unimplemented op builders. (#5372)
foin6 May 16, 2024
4494c86
rocblas -> hipblas changes for ROCm (#5401)
rraminen May 17, 2024
2c0dcac
Rocm warp size fix (#5402)
rraminen May 17, 2024
f53895f
Optimize zero3 fetch params using all_reduce (#5420)
deepcharm May 20, 2024
bb146c3
CPUAdam fp16 and bf16 support (#5409)
BacharL May 20, 2024
31f11c0
Fix the TypeError for XPU Accelerator (#5531)
shiyang-weng May 20, 2024
35b4813
Fix RuntimeError for moe on XPU: tensors found at least two devices (…
shiyang-weng May 21, 2024
cf0ccb5
Remove synchronize calls from allgather params (#5516)
BacharL May 21, 2024
e388056
Avoid overwrite of compiled module wrapper attributes (#5549)
deepcharm May 21, 2024
5ff0d44
Small typos in functions set_none_gradients_to_zero (#5557)
TravelLeraLone May 21, 2024
29ab009
Adapt doc for #4405 (#5552)
oraluben May 21, 2024
633da3d
Update to HF_HOME from TRANSFORMERS_CACHE (#4816)
loadams May 22, 2024
9db010e
[INF] DSAttention allow input_mask to have false as value (#5546)
oelayan7 May 22, 2024
bd2b2ef
Add throughput timer configuration (#5363)
deepcharm May 22, 2024
3c5aa00
Add Ulysses DistributedAttention compatibility (#5525)
Kwen-Chen May 22, 2024
d7f9be6
Add hybrid_engine.py as path to trigger the DS-Chat GH workflow (#5562)
lekurile May 23, 2024
c160d76
Update HPU docker version (#5566)
loadams May 28, 2024
c203830
[MiCS] Remove the handle print on DeepSpeed side (#5574)
ys950902 May 28, 2024
5e5c8a7
Rename files in fp_quantize op from quantize.* to fp_quantize.* (#5577)
loadams May 28, 2024
ff01ade
Update to fix sidebar over text (#5567)
loadams May 28, 2024
83920f6
DeepSpeedCheckpoint: support custom final ln idx (#5506)
nelyahu May 28, 2024
a6076cf
Update minor CUDA version compatibility (#5591)
adk9 May 31, 2024
9db9970
Add slide deck for meetup in Japan (#5598)
tohtana May 31, 2024
c6f151c
Fixed the Windows build. (#5596)
costin-eseanu May 31, 2024
0bf3511
estimate_zero2_model_states_mem_needs: fixing memory estiamtion (#5099)
nelyahu Jun 4, 2024
cca53b0
Fix cuda hardcode for inference woq (#5565)
Liangliang-Ma Jun 5, 2024
31815d9
fix sequence parallel(Ulysses) grad scale for zero0 (#5555)
inkcherry Jun 5, 2024
6ad125e
Add Compressedbackend for Onebit optimizers (#5473)
Liangliang-Ma Jun 5, 2024
9c15b8f
Updated hpu-gaudi2 tests content. (#5622)
vshekhawat-hlab Jun 6, 2024
2e4bc1d
Pin transformers version for MII tests (#5629)
loadams Jun 7, 2024
e5b4d41
WA for Torch-compile-Z3-act-apt accuracy issue from the Pytorch repo …
NirSonnenschein Jun 10, 2024
8a4d03c
stage_1_and_2: optimize clip calculation to use clamp (#5632)
nelyahu Jun 10, 2024
5e5b1f7
Fix overlap communication of ZeRO stage 1 and 2 (#5606)
penn513 Jun 10, 2024
c47ad5f
Merge branch 'master' of https://github.com/Snowflake-Labs/deepspeed …
sfc-gh-reyazda Jun 10, 2024
277902a
remove float8 dtype
sfc-gh-reyazda Jun 10, 2024
74311af
Merge branch 'master' into fix-quantized-shape
sfc-gh-reyazda Jul 11, 2024
9eb12fb
Merge branch 'master' into fix-quantized-shape
sfc-gh-reyazda Jan 9, 2025
24 changes: 18 additions & 6 deletions csrc/fp_quantizer/fp_quantize.cpp
@@ -29,13 +29,13 @@ at::Tensor quantize(torch::Tensor& out,
int q_bits,
int q_mantisa_bits)
{
int total_elems = at::numel(val);
size_t total_elems = at::numel(val);
float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges
(q_bits == 12 ? 510.0 : // fp12 range
(q_bits == 6 ? 28.0 : // fp6 range
6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4
// in case accuracy is not matching!
int num_groups = total_elems / group_size;
size_t num_groups = total_elems / group_size;

DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
@@ -45,6 +45,18 @@ at::Tensor quantize(torch::Tensor& out,
return out;
}

at::Tensor get_scales(torch::Tensor& out, int num_groups)

Reviewer comment (Contributor): This function is redefined at line 118.

{
auto options = at::TensorOptions()
.dtype(torch::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto scales =
torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options);
return scales;
}

#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(), \
@@ -63,9 +75,9 @@ void dequantize(torch::Tensor& val,
int q_mantisa_bits,
int q_exponent_bits)
{
int total_elems = at::numel(val);
size_t total_elems = at::numel(val);

int num_groups = total_elems / group_size;
size_t num_groups = total_elems / group_size;

DISPATCH_DEQUANTIZE(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
@@ -93,9 +105,9 @@ void selective_dequantize(torch::Tensor& val,
int q_mantisa_bits,
int q_exponent_bits)
{
int total_elems = at::numel(val);
size_t total_elems = at::numel(val);
int num_indexes = indexes.size(0);
int num_groups = total_elems / group_size;
size_t num_groups = total_elems / group_size;

DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
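
For orientation, here is a minimal PyTorch sketch (not part of the diff) of the packed layout that these kernels and the Python wrapper below assume for q_bits=8: each group stores group_size quantized bytes followed by a 4-byte fp32 scale, so the per-group scales can be recovered by reinterpreting the trailing bytes of each row, much as get_scales() does with torch::from_blob.

```python
import torch

group_size = 8   # illustrative value; the real group size comes from the quantizer config
num_groups = 4

# Packed buffer: per group, `group_size` uint8 payload bytes plus 4 bytes holding an fp32 scale.
packed = torch.zeros(num_groups, group_size + 4, dtype=torch.uint8)

# Quantized payload: everything except the trailing 4 bytes of each row.
q_data = packed[:, :-4]                                    # shape: (num_groups, group_size)

# Scales: reinterpret the trailing 4 bytes of each row as one float32 per group.
scales = packed[:, -4:].contiguous().view(torch.float32)   # shape: (num_groups, 1)

print(q_data.shape, scales.shape)
```
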
44 changes: 22 additions & 22 deletions csrc/fp_quantizer/fp_quantize.cu
@@ -71,10 +71,10 @@ __global__ void apply_quantization(T* val,
std::pair<uint64_t, uint64_t> seed,
float q_range)
{
int tidx = threadIdx.x;
int wid = tidx >> 5;
int lane = tidx & 0x1f;
int gid = blockIdx.x * quantization::warps + wid;
unsigned int tidx = threadIdx.x;
unsigned int wid = tidx >> 5;
unsigned int lane = tidx & 0x1f;
unsigned int gid = blockIdx.x * quantization::warps + wid;

constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1;
@@ -98,7 +98,7 @@ __global__ void apply_quantization(T* val,
T cur_max;
reduce::init<ROp::Max>(&cur_max);

int idx = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);

@@ -228,7 +228,7 @@ template <typename T,
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
{
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;
unsigned int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;

constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
@@ -308,7 +308,6 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int
if (dst_exponent != (1 << q_exponent_bits) - 1)
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
(1 << (q_exponent_bits - 1)) - 1;

q_buf[j] =
((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
@@ -334,7 +333,7 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int
template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
uint8_t* q_val,
int num_groups,
size_t num_groups,
int group_size,
cudaStream_t stream,
float q_range,
@@ -344,12 +343,12 @@ void launch_quantization(T* val,
{
const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
const dim3 block(quantization::threads);

std::pair<uint64_t, uint64_t> seed = FPContext::Instance().IncrementOffset(16);

constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);

const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;

QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] {
switch (copy_unroll) {
LAUNCH_FOR_QUANTIZATION_UNROLL(1)
@@ -363,7 +362,7 @@ void launch_quantization(T* val,
}
#define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \
template void launch_quantization<T, mantisa, exponent>( \
T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int);
T*, uint8_t*, size_t, int, cudaStream_t, float q_range, int, int, int);
// fp8(E4M3), nearest-rounding
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
@@ -373,7 +372,7 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
T* q_val,
int num_groups,
size_t num_groups,
int group_size,
int q_mantisa_bits,
int q_exponent_bits,
@@ -390,7 +389,8 @@ void launch_dequantization(uint8_t* val,
});
}
#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
template void launch_dequantization<T, mantisa>(uint8_t*, T*, int, int, int, int, cudaStream_t);
template void launch_dequantization<T, mantisa>( \
uint8_t*, T*, size_t, int, int, int, cudaStream_t);
// fp8(E4M3)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
@@ -406,12 +406,12 @@ __global__ void apply_selective_dequantization(uint8_t* val,
T* q_val,
int32_t* indexes,
int group_size,
int total_num_elements)
size_t total_num_elements)
{
int index = indexes[blockIdx.x];
unsigned int index = indexes[blockIdx.x];
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
int input_index = index * total_num_elements + tidx;
unsigned int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
unsigned int input_index = index * total_num_elements + tidx;
constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
@@ -504,17 +504,17 @@ template <typename T, int mantisa>
void launch_selective_dequantization(uint8_t* val,
T* q_val,
int32_t* indexes,
int num_groups,
size_t num_groups,
int group_size,
int num_indexes,
int q_mantisa_bits,
int q_exponent_bits,
cudaStream_t stream)
{
int total_elements_per_index = (num_groups / num_indexes) * group_size;
int blocks = (total_elements_per_index - 1) /
(quantization::threads * (quantization::access_granularity / sizeof(T))) +
1;
size_t total_elements_per_index = (num_groups / num_indexes) * group_size;
size_t blocks = (total_elements_per_index - 1) /
(quantization::threads * (quantization::access_granularity / sizeof(T))) +
1;
const dim3 grid(num_indexes, blocks);
const dim3 block(quantization::threads);
DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
Expand All @@ -524,7 +524,7 @@ void launch_selective_dequantization(uint8_t* val,
}
#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \
template void launch_selective_dequantization<T, mantisa>( \
uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t);
uint8_t*, T*, int32_t*, size_t, int, int, int, int, cudaStream_t);
// fp8(E4M3)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7);
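
A side note on the int to size_t / unsigned int changes in this file (the rationale is inferred here, not stated in the PR): element and group counters can exceed the signed 32-bit range for very large tensors, for example:

```python
# Back-of-the-envelope check with illustrative numbers.
INT32_MAX = 2**31 - 1

total_elems = 8 * 1024**3          # an 8-billion-element tensor (assumed size)
group_size = 128                   # assumed group size
num_groups = total_elems // group_size

print(total_elems > INT32_MAX)     # True -> a 32-bit `int total_elems` would overflow
print(num_groups)                  # 67108864; safe in size_t on the C++ side
```
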
6 changes: 3 additions & 3 deletions csrc/fp_quantizer/includes/fp_quantize.h
@@ -99,7 +99,7 @@
template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
uint8_t* q_val,
int num_groups,
size_t num_groups,
int group_size,
cudaStream_t stream,
float q_range,
@@ -110,7 +110,7 @@ void launch_quantization(T* val,
template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
T* q_val,
int num_groups,
size_t num_groups,
int group_size,
int q_mantisa_bits,
int q_exponent_bits,
@@ -120,7 +120,7 @@ template <typename T, int mantisa>
void launch_selective_dequantization(uint8_t* val,
T* q_val,
int32_t* indexes,
int num_groups,
size_t num_groups,
int group_size,
int num_indexes,
int q_mantisa_bits,
3 changes: 2 additions & 1 deletion deepspeed/linear/quantization.py
@@ -62,7 +62,8 @@ def _ensure_quantized(self, tensor: torch.Tensor):
tensor.data = self.quantizer.quantize(tensor.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
assert tensor.dtype == torch.uint8
assert (tensor.dtype == torch.int8), \

Reviewer comment (hwchen2017, Contributor, Jan 13, 2025): shouldn't it be torch.uint8 instead?

f"Quantize conversion dtype ({tensor.dtype}) error!"

def dequantized(self) -> torch.Tensor:
"""
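
As context for the reviewer's question above, a standalone sketch of the assertion with a hypothetical helper; the expected dtype is exactly what is in dispute, and the packed buffer allocated by the quantizer wrapper further down is torch.uint8:

```python
import torch

def check_quantized_dtype(tensor: torch.Tensor, expected: torch.dtype = torch.uint8) -> None:
    # Hypothetical helper mirroring the assert in _ensure_quantized();
    # the diff uses torch.int8 here, while the reviewer suggests torch.uint8.
    assert tensor.dtype == expected, \
        f"Quantize conversion dtype ({tensor.dtype}) error!"

check_quantized_dtype(torch.zeros(4, dtype=torch.uint8))   # passes
```
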
1 change: 1 addition & 0 deletions deepspeed/ops/__init__.py
@@ -13,3 +13,4 @@
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig

from ..git_version_info import compatible_ops as __compatible_ops__
from . import fp_quantizer
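
The new re-export means the FP quantizer subpackage is loaded along with deepspeed.ops; a quick check of the intended effect (the attribute access is an assumption, and the op itself still needs to be buildable on the target accelerator):

```python
import deepspeed.ops

# After this change, importing the parent package also pulls in fp_quantizer.
print(hasattr(deepspeed.ops, "fp_quantizer"))   # expected: True
```
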
37 changes: 17 additions & 20 deletions deepspeed/ops/fp_quantizer/quantize.py
@@ -54,8 +54,9 @@ def quantize(self,
q_bits=8,
q_mantisa_bits=3,
stochastic_mode=False,
return_meta_tensor=False) -> torch.Tensor:
assert input.dtype == torch.bfloat16, "only support bf16 for now"
return_meta_tensor=False,
out=None) -> torch.Tensor:
assert input.dtype == torch.bfloat16, f"only support bf16 for now, dtype is {input.dtype}"
if return_meta_tensor:
assert q_bits == 8, "meta tensor is only supported with q_bit=8"

Expand All @@ -73,23 +74,23 @@ def quantize(self,
else:
assert (0), \
f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"

self.num_groups = input.numel() // self.group_size
self.input_q = torch.ones(self.num_groups,
int(self.group_size * q_bits) // 8 + 4,
dtype=torch.uint8,
device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
self.input_q = torch.ones(
self.num_groups, int(self.group_size * q_bits) // 8 +
4, dtype=torch.uint8, device=input.device) if out is None else out
input_q_reshaped = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits,
q_mantisa_bits)
if return_meta_tensor:
data, self.scale = out.split(self.group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
self.scale = self.scale.contiguous()
self.scales = input_q_reshaped[:, -4:].contiguous().reshape(-1, 4)
input_q_reshaped = self.input_q[:, :-4].contiguous().reshape(self.orig_shape)
del self.input_q
del out
gc.collect()
get_accelerator().empty_cache()
return data, self.scale
self.input_q = None
return input_q_reshaped, self.scales
return input_q_reshaped

return out
def get_scales(self):
return fp_quant_module.get_scales(self.scales, self.num_groups)

def to(self, *args, **kwargs):
# Intermediate tensors may need to be moved to different devices
@@ -123,6 +124,7 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)

return fp_out

def selective_dequantize(self,
@@ -151,11 +153,6 @@ def selective_dequantize(self,
assert (0), \
f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

if scale is not None:
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
q_bits - q_mantisa_bits - 1)
return fp_out
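
Putting the API changes in this file together, a hedged round-trip sketch of how the reshaped output and the separate scales are meant to be used; the FP_Quantize class name and its group_size constructor argument are assumptions, the method signatures follow the diff above, and a CUDA device with the fp_quantizer op built is required:

```python
import torch
from deepspeed.ops.fp_quantizer import FP_Quantize   # assumed export

quantizer = FP_Quantize(group_size=128)               # assumed constructor argument

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# With this PR, the quantized data keeps `weight`'s original shape instead of the
# internal (num_groups, group_size + 4) packed layout; scales are returned separately.
q_weight, scales = quantizer.quantize(weight, q_bits=8, return_meta_tensor=True)
assert q_weight.shape == weight.shape

# Dequantize by passing the scales back in; the caller provides the output buffer here.
fp_out = torch.empty_like(weight)
quantizer.dequantize(q_weight, fp_out=fp_out, q_bits=8, scale=scales)
```
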