Skip to content

Commit 19f6d68

Browse files
authored
Fix batch infer for gemma3vl (#3592)
* fix gemma3vl * upgrade to torch2.6
1 parent 90f3209 commit 19f6d68

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

.github/workflows/pr_ete_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ jobs:
5858
run: |
5959
python3 -m pip cache dir
6060
python3 -m pip install --upgrade pip setuptools==69.5.1
61-
python3 -m pip install torch==2.5.1 torchvision==0.20.1
61+
python3 -m pip install torch==2.6.0 torchvision==0.21.0
6262
# install the wheel package from https://github.com/Dao-AILab/flash-attention/releases
63-
python3 -m pip install /root/packages/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
63+
python3 -m pip install /root/packages/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
6464
- name: Build lmdeploy
6565
run: |
6666
cp /nvme/qa_test_models/offline_pkg/openmpi-4.1.5.tar.gz .

lmdeploy/vl/model/gemma3_vl.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,20 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
7474
)
7575
images = self.collect_images(messages)
7676
images = [image.convert('RGB') for image, _ in images]
77+
num_image = len(images)
7778
images = make_nested_list_of_images(images)
7879
image_inputs = self.processor.image_processor(images, **output_kwargs['images_kwargs'])
79-
image_inputs['image_tokens'] = self.image_tokens
80-
image_inputs['image_token_id'] = self.image_token_id
81-
messages.append(dict(role='preprocess', content=[image_inputs]))
80+
outputs = []
81+
for idx in range(num_image):
82+
pixel_values = image_inputs['pixel_values'][idx:idx + 1, ...]
83+
num_crops = image_inputs['num_crops'][idx:idx + 1]
84+
data = dict(pixel_values=pixel_values,
85+
num_crops=num_crops,
86+
image_tokens=self.image_tokens,
87+
image_token_id=self.image_token_id)
88+
outputs.append(data)
89+
90+
messages.append(dict(role='preprocess', content=outputs))
8291
return messages
8392

8493
@torch.no_grad()

0 commit comments

Comments
 (0)