
Commit 32e6302

Authored by oyilmaz-nvidia, abharwani, and pre-commit-ci[bot]
Updates for TRT-LLM 0.9 (#8873)
* Upgrade to TRT-LLM 0.9
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update GPT to config-based export
* Fix for LoRA checkpoint
* Fix for the in-flight batching case
* Update Falcon for TRT-LLM 0.9
* Removed unused import and comment

Signed-off-by: Onur Yilmaz <[email protected]>
Co-authored-by: abharwani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e9d8266 commit 32e6302

File tree: 6 files changed (+44, -41 lines)

nemo/export/trt_llm/decoder/falcon.py

Lines changed: 2 additions & 4 deletions
@@ -17,8 +17,7 @@

 from tensorrt_llm.functional import non_gated_version
 from tensorrt_llm.models.falcon.model import FalconDecoderLayer
-from tensorrt_llm.models.modeling_utils import PretrainedConfig
-from tensorrt_llm.quantization import QuantMode
+from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
 from typing_extensions import override

 from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder

@@ -119,8 +118,7 @@ def build_decoder(self, layer):
             world_size=self.tensor_parallel,
             tp_size=self.tensor_parallel,
             pp_size=1,
-            quant_mode=QuantMode(0),
-            quant_kwargs=None,
+            quantization=QuantConfig(),
             max_lora_rank=layer.max_lora_rank,
             use_parallel_embedding=False,
         )
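
For context, a minimal sketch of the quantization-API change this hunk tracks (the QuantConfig import path comes from the diff above; the helper name is hypothetical): in TRT-LLM 0.9 the quant_mode/quant_kwargs pair on the decoder-layer config is replaced by a single quantization argument, and a default QuantConfig() plays the role of the old QuantMode(0), i.e. no quantization.

from tensorrt_llm.models.modeling_utils import QuantConfig

def default_quantization_kwargs() -> dict:
    # Passed into the decoder-layer PretrainedConfig alongside world_size,
    # tp_size, pp_size, max_lora_rank, etc. A bare QuantConfig() means the
    # engine is built without quantization.
    return {"quantization": QuantConfig()}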

nemo/export/trt_llm/decoder/gpt.py

Lines changed: 27 additions & 19 deletions
@@ -17,6 +17,7 @@

 from tensorrt_llm.layers import AttentionMaskType, PositionEmbeddingType
 from tensorrt_llm.models.gpt.model import GPTDecoderLayer
+from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
 from typing_extensions import override

 from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder

@@ -85,37 +86,44 @@ class GPTDecoderLayerBuilder(DecoderLayerBuilder):
     @override
     def build_decoder(self, layer):
         rotary_pct = layer.rotary_pct
-        position_embedding_type = (
-            PositionEmbeddingType.rope_gpt_neox
-            if layer.position_embedding_type == "rope"
-            else PositionEmbeddingType.learned_absolute
-        )

-        assert not (position_embedding_type == PositionEmbeddingType.rope_gpt_neox and rotary_pct == 0.0)
+        position_embedding_type = "rope_gpt_neox" if layer.position_embedding_type == "rope" else "learned_absolute"
+
+        assert not (position_embedding_type == "rope_gpt_neox" and rotary_pct == 0.0)

         bias_qkv = layer.attention.qkv.bias is not None

         rotary_scaling = None
         if layer.rotary_scaling is not None:
             rotary_scaling = {"type": "linear", "factor": float(layer.rotary_scaling)}

-        return GPTDecoderLayer(
+        config = PretrainedConfig(
+            architecture=None,
+            dtype=self.dtype,
+            logits_dtype=self.dtype,
+            vocab_size=layer.vocab_size,
+            max_position_embeddings=self.max_position_embeddings,
             hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_layers,
             num_attention_heads=self.num_attention_heads,
-            max_position_embeddings=self.max_position_embeddings,
-            num_layers=self.num_layers,
-            dtype=self.dtype,
-            apply_query_key_layer_scaling=False,
-            attention_mask_type=AttentionMaskType.causal,
+            num_key_value_heads=self.num_kv_heads,
             hidden_act=self.hidden_act,
+            intermediate_size=layer.ffn_hidden_size_local * self.tensor_parallel,
+            norm_epsilon=layer.norm_epsilon,
             position_embedding_type=position_embedding_type,
-            rotary_embedding_percentage=rotary_pct,
-            rotary_base=layer.rotary_base,
-            rotary_scaling=rotary_scaling,
-            inter_size=layer.ffn_hidden_size_local * self.tensor_parallel,
-            bias=bias_qkv,
-            num_kv_heads=self.num_kv_heads,
-            tp_group=self.tp_group,
+            world_size=self.tensor_parallel,
             tp_size=self.tensor_parallel,
+            pp_size=1,
             max_lora_rank=layer.max_lora_rank,
+            quantization=QuantConfig(),
         )
+
+        config.set_if_not_exist('hidden_act', self.hidden_act)
+        config.set_if_not_exist('apply_query_key_layer_scaling', False)
+        config.set_if_not_exist('bias', bias_qkv)
+        config.set_if_not_exist('rotary_base', layer.rotary_base)
+        config.set_if_not_exist('rotary_scaling', rotary_scaling)
+        config.set_if_not_exist('rotary_pct', rotary_pct)
+        config.set_if_not_exist('moe_num_experts', 0)
+
+        return GPTDecoderLayer(config=config, layer_idx=self.layer_id,)
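
The GPT builder now constructs a PretrainedConfig and fills optional fields through set_if_not_exist. Below is a stand-in sketch of that pattern (plain Python, not TRT-LLM code) showing the semantics the chain of set_if_not_exist calls relies on: a field is only populated when the config does not already carry it.

class _ConfigSketch:
    """Stand-in mimicking PretrainedConfig.set_if_not_exist behaviour."""

    def set_if_not_exist(self, key: str, value):
        # Only fill the attribute when it was not already provided.
        if not hasattr(self, key):
            setattr(self, key, value)


cfg = _ConfigSketch()
cfg.rotary_pct = 0.5                        # value derived from the NeMo layer
cfg.set_if_not_exist('rotary_pct', 1.0)     # ignored: already set above
cfg.set_if_not_exist('moe_num_experts', 0)  # added: no prior value
assert cfg.rotary_pct == 0.5 and cfg.moe_num_experts == 0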

nemo/export/trt_llm/decoder/llama.py

Lines changed: 2 additions & 4 deletions
@@ -18,8 +18,7 @@
 from tensorrt_llm.functional import non_gated_version
 from tensorrt_llm.layers import MoeConfig
 from tensorrt_llm.models.llama.model import LLaMADecoderLayer
-from tensorrt_llm.models.modeling_utils import PretrainedConfig
-from tensorrt_llm.quantization import QuantMode
+from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
 from typing_extensions import override

 from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder

@@ -118,9 +117,8 @@ def build_decoder(self, layer):
             world_size=self.tensor_parallel,
             tp_size=self.tensor_parallel,
             pp_size=1,
-            quant_mode=QuantMode(0),
-            quant_kwargs=None,
             max_lora_rank=layer.max_lora_rank,
+            quantization=QuantConfig(),
         )

         config.set_if_not_exist('mlp_bias', False)

nemo/export/trt_llm/tensorrt_llm_build.py

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,7 @@
 from tensorrt_llm._utils import np_dtype_to_trt
 from tensorrt_llm.builder import Builder
 from tensorrt_llm.logger import logger
+from tensorrt_llm.models.modeling_utils import add_lora
 from tensorrt_llm.network import net_guard
 from tensorrt_llm.plugin.plugin import ContextFMHAType
 from tensorrt_llm.quantization import QuantMode

@@ -170,6 +171,9 @@ def _build_impl(tensorrt_llm_model, args):
     timing_cache_file = args.timing_cache if args.timing_cache else args.output_dir / "model.cache"
     timing_cache = timing_cache_file

+    if args.use_lora_plugin is not None:
+        add_lora(tensorrt_llm_model, args.max_lora_rank)
+
     builder = Builder()
     apply_query_key_layer_scaling = False
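
A short sketch of the LoRA hook added here, assuming the same args namespace that _build_impl already receives (use_lora_plugin and max_lora_rank); the wrapper name is hypothetical. add_lora mutates the TRT-LLM model in place before the Builder run, exactly as in the hunk above.

from tensorrt_llm.models.modeling_utils import add_lora

def maybe_add_lora(tensorrt_llm_model, args):
    # Attach LoRA modules only when the LoRA plugin was requested;
    # otherwise the model is built unchanged.
    if args.use_lora_plugin is not None:
        add_lora(tensorrt_llm_model, args.max_lora_rank)
    return tensorrt_llm_model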

nemo/export/trt_llm/tensorrt_llm_model.py

Lines changed: 5 additions & 13 deletions
@@ -144,15 +144,7 @@ def forward(
         if attention_mask is not None:
             attention_mask = expand_mask(attention_mask, shape(input_ids, -1))

-        for layer_idx, (layer, past, pointer, host_pointer, max_attention_window_size) in enumerate(
-            zip(
-                self.layers,
-                kv_cache_params.past_key_value,
-                kv_cache_params.kv_cache_block_pointers,
-                kv_cache_params.host_kv_cache_block_pointers,
-                kv_cache_params.host_max_attention_window_sizes,
-            )
-        ):
+        for layer_idx, (layer, past) in enumerate(zip(self.layers, kv_cache_params.past_key_value,)):

             decoder_params = {
                 "hidden_states": hidden_states,

@@ -161,8 +153,8 @@ def forward(
                 "kv_cache_params": KeyValueCacheParams(
                     past_key_value=[past],
                     host_past_key_value_lengths=kv_cache_params.host_past_key_value_lengths,
-                    kv_cache_block_pointers=[pointer],
-                    host_max_attention_window_sizes=max_attention_window_size,
+                    kv_cache_block_pointers=kv_cache_params.kv_cache_block_pointers,
+                    host_max_attention_window_sizes=kv_cache_params.host_max_attention_window_sizes,
                     cache_indirection=kv_cache_params.cache_indirection,
                     host_sink_token_length=kv_cache_params.host_sink_token_length,
                     host_kv_cache_block_pointers=kv_cache_params.host_kv_cache_block_pointers,

@@ -329,8 +321,8 @@ def prepare_inputs(
                 past_key_value=model_inputs['past_key_value'],
                 host_past_key_value_lengths=model_inputs['host_past_key_value_lengths'],
                 host_max_attention_window_sizes=model_inputs['host_max_attention_window_sizes'],
-                kv_cache_block_pointers=model_inputs['kv_cache_block_pointers_list'],
-                host_kv_cache_block_pointers=model_inputs['host_kv_cache_block_pointers_list'],
+                kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'],
+                host_kv_cache_block_pointers=model_inputs['host_kv_cache_block_pointers'],
                 cache_indirection=model_inputs['cache_indirection'],
                 host_sink_token_length=model_inputs['host_sink_token_length'],
             ),
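
To make the loop change explicit, here is a stand-in illustration (plain Python, not the NeMo module) of the new per-layer contract: only past_key_value is sliced per decoder layer, while the block-pointer and attention-window fields are forwarded to every layer unchanged.

from types import SimpleNamespace

kv_cache_params = SimpleNamespace(
    past_key_value=["past_layer0", "past_layer1"],
    kv_cache_block_pointers=["shared_pointers"],             # no per-layer slice
    host_max_attention_window_sizes="shared_window_sizes",   # no per-layer slice
)
layers = ["layer0", "layer1"]

for layer_idx, (layer, past) in enumerate(zip(layers, kv_cache_params.past_key_value)):
    per_layer_kwargs = {
        "past_key_value": [past],  # one entry, specific to this layer
        "kv_cache_block_pointers": kv_cache_params.kv_cache_block_pointers,
        "host_max_attention_window_sizes": kv_cache_params.host_max_attention_window_sizes,
    }
    assert per_layer_kwargs["past_key_value"] == [kv_cache_params.past_key_value[layer_idx]]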

nemo/export/trt_llm/tensorrt_llm_run.py

Lines changed: 4 additions & 1 deletion
@@ -24,12 +24,14 @@
 import torch
 from mpi4py.futures import MPIPoolExecutor
 from tensorrt_llm.logger import logger
+from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.quantization import QuantMode
-from tensorrt_llm.runtime import LoraManager, ModelConfig, SamplingConfig
+from tensorrt_llm.runtime import ModelConfig, SamplingConfig
 from transformers import PreTrainedTokenizer

 from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group
 from nemo.export.trt_llm.tensorrt_llm_model import LMHeadModelBuilder
+
 from nemo.export.trt_llm.tensorrt_llm_build import get_engine_name, MODEL_NAME, refit_runtime_engine  # isort:skip
 from nemo.export.trt_llm.nemo_utils import to_word_list_format  # isort:skip

@@ -90,6 +92,7 @@ def _read_config(config_path: Path):
     model_config = ModelConfig(
         model_name=config["builder_config"]["name"],
         max_batch_size=config["builder_config"]["max_batch_size"],
+        max_beam_width=config["builder_config"]["max_beam_width"],
         vocab_size=config["builder_config"]["vocab_size"],
         num_layers=config["builder_config"]["num_layers"],
         num_heads=num_heads,
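
A hedged sketch of the config read this hunk extends (hypothetical helper; field names taken from the diff): the 0.9 runtime path now also forwards max_beam_width from the engine's builder_config into tensorrt_llm.runtime.ModelConfig.

import json
from pathlib import Path

def read_builder_config(config_path: Path) -> dict:
    # Subset of the fields _read_config() forwards into ModelConfig.
    with open(config_path) as f:
        config = json.load(f)
    builder = config["builder_config"]
    return {
        "model_name": builder["name"],
        "max_batch_size": builder["max_batch_size"],
        "max_beam_width": builder["max_beam_width"],  # newly read in the 0.9 path
        "vocab_size": builder["vocab_size"],
        "num_layers": builder["num_layers"],
    }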
