I attempted to perform inference using OpenVINO on quantized versions of several models (with inputs from the ImageNet dataset). The inference works correctly on both the CPU and NPU, successfully recognizing different classes. However, when running on the iGPU, the output consistently predicts the same class, regardless of the input image. I tested this extensively on the full dataset and observed the same issue across various quantized models (e.g., MobileNet).
Upon investigation using the code below, I found that the logits output by the iGPU fall in an extremely narrow range (the same [-0.05, 0.07] for every test image in the log), which likely explains why the predicted class remains constant.
Since the models are identical across devices, why does the iGPU produce such abnormal results? Does the iGPU have special requirements or limitations when executing quantized models in OpenVINO?
import numpy as np
import openvino as ov

# preprocess_image(image_path) is called below; its body is not shown in this report.

def test_model_on_devices(model_path, test_images):
    core = ov.Core()
    model = core.read_model(model_path)
    devices = ["CPU", "GPU", "NPU"]
    results = {}
    for device in devices:
        try:
            print(f"\n=== Testing on {device} ===")
            compiled_model = core.compile_model(model, device)
            infer_request = compiled_model.create_infer_request()
            device_results = []
            for idx, image_path in enumerate(test_images):
                input_tensor = preprocess_image(image_path)
                infer_request.set_input_tensor(ov.Tensor(input_tensor))
                infer_request.infer()
                # First (and only) batch element of the first output: the raw logits.
                output = infer_request.get_output_tensor(0).data[0]
                predicted_class = int(np.argmax(output))
                # "Confidence" here is the maximum raw logit, not a probability.
                confidence = float(np.max(output))
                device_results.append({
                    'class': predicted_class,
                    'confidence': confidence,
                    'logit_range': [np.min(output), np.max(output)]
                })
                print(f"Image {idx+1}: Class {predicted_class}, "
                      f"Confidence {confidence:.4f}, "
                      f"Logits [{np.min(output):.2f}, {np.max(output):.2f}]")
            results[device] = device_results
        except Exception as e:
            # Skip devices that are missing or fail to compile the model.
            print(f"Error on {device}: {e}")
            continue
    return results

def main():
    model_path = "/home/code/multiDNN/MultiDNN_Models/resnet_v1-50-imagenet-pruned85.4block_quantized.onnx"
    test_images = [
        "/home/imageNet_dataset/ILSVRC/Data/CLS-LOC/val/n02116738/ILSVRC2012_val_00047535.JPEG",
        "/home/imageNet_dataset/ILSVRC/Data/CLS-LOC/val/n01641577/ILSVRC2012_val_00025633.JPEG",
        "/home/imageNet_dataset/ILSVRC/Data/CLS-LOC/val/n03871628/ILSVRC2012_val_00048769.JPEG"
    ]
    results = test_model_on_devices(model_path, test_images)
    print("\n=== Summary ===")
    for device, device_results in results.items():
        classes = [r['class'] for r in device_results]
        print(f"{device}: Predicted classes {classes}")
        # A single repeated class across different images signals a broken output.
        if len(set(classes)) == 1:
            print(f"  WARNING: All images classified as class {classes[0]}!")

if __name__ == "__main__":
    main()
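The summary loop in the script only warns when a device collapses to a single class. A complementary check, sketched here under the assumption that the CPU output is the trusted reference, is to measure how often each device agrees with the CPU. The helper name `compare_to_reference` is hypothetical; the input mirrors the `results` dictionary built by the script, and the sample values are the predicted classes from the reported log.

```python
def compare_to_reference(results, reference="CPU"):
    """Return, per device, the fraction of images whose class matches the reference device."""
    ref_classes = [r["class"] for r in results[reference]]
    agreement = {}
    for device, device_results in results.items():
        if device == reference:
            continue
        classes = [r["class"] for r in device_results]
        # Count positions where the device agrees with the reference prediction.
        matches = sum(a == b for a, b in zip(ref_classes, classes))
        agreement[device] = matches / len(ref_classes) if ref_classes else 0.0
    return agreement


# Predicted classes copied from the reported log output.
results_from_log = {
    "CPU": [{"class": 275}, {"class": 30}, {"class": 692}],
    "GPU": [{"class": 623}, {"class": 623}, {"class": 623}],
    "NPU": [{"class": 275}, {"class": 30}, {"class": 692}],
}

agreement = compare_to_reference(results_from_log)
for device, fraction in agreement.items():
    print(f"{device}: {fraction:.0%} agreement with CPU")
```

On a full validation run, this kind of agreement score would also catch partial divergence that the single-class warning misses.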
Below are tested images:
Relevant log output
=== Testing on CPU ===
Image 1: Class 275, Confidence 17.3129, Logits [-8.89, 17.31]
Image 2: Class 30, Confidence 10.9580, Logits [-7.98, 10.96]
Image 3: Class 692, Confidence 16.4645, Logits [-10.47, 16.46]
=== Testing on GPU ===
Image 1: Class 623, Confidence 0.0654, Logits [-0.05, 0.07]
Image 2: Class 623, Confidence 0.0654, Logits [-0.05, 0.07]
Image 3: Class 623, Confidence 0.0654, Logits [-0.05, 0.07]
=== Testing on NPU ===
Image 1: Class 275, Confidence 17.0781, Logits [-8.93, 17.08]
Image 2: Class 30, Confidence 10.3594, Logits [-8.20, 10.36]
Image 3: Class 692, Confidence 16.5312, Logits [-10.34, 16.53]
=== Summary ===
CPU: Predicted classes [275, 30, 692]
GPU: Predicted classes [623, 623, 623]
WARNING: All images classified as class 623!
NPU: Predicted classes [275, 30, 692]
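The GPU rows in the log show the same logit range, [-0.05, 0.07], for every image. As a rough illustration of how little information such a range can carry, the sketch below bounds the top-1 softmax probability from the logit range alone. It assumes a 1000-class ImageNet output (the class count is an assumption, not stated in the report), and the helper name is hypothetical.

```python
import math


def max_softmax_prob_upper_bound(logit_min, logit_max, num_classes):
    """Upper bound on the top-1 softmax probability given only the logit range."""
    # Most favorable case for the top class: one logit at logit_max,
    # every other logit at logit_min.
    top = math.exp(logit_max - logit_min)
    return top / (top + (num_classes - 1))


# Logit ranges copied from Image 1 in the log; 1000 classes is an assumption.
cpu_bound = max_softmax_prob_upper_bound(-8.89, 17.31, 1000)
gpu_bound = max_softmax_prob_upper_bound(-0.05, 0.07, 1000)
print(f"CPU upper bound: {cpu_bound:.6f}")
print(f"GPU upper bound: {gpu_bound:.6f}")
```

Under these assumptions the GPU output cannot exceed a top-1 probability of roughly 0.1%, which is practically a uniform distribution, while the CPU range allows a probability close to 1.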
Issue submission checklist
I'm reporting an issue. It's not a question.
I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
There is reproducer code and related data files such as images, videos, models, etc.
Hi @Jiawei888, this is a known issue when running quantized ONNX models on the iGPU with OpenVINO. Even though the models work fine on CPU and NPU, the iGPU often struggles with certain quantization operations (such as QuantizeLinear and DequantizeLinear) that are common in SparseZoo models. As a result, you get nearly identical predictions and a very narrow output range. Essentially, the quantization is not being handled properly on the iGPU side.
I recommend converting the model to OpenVINO’s IR format in FP16 using the Model Optimizer. This helps OpenVINO better understand and optimize the model for the iGPU.
mo \
  --input_model /path/to/model.onnx \
  --data_type FP16 \
  --output_dir /path/to/output_dir
After the conversion, update your script to load the generated .xml file instead of the ONNX model. In most cases, this resolves the issue and enables consistent predictions across CPU, GPU, and NPU.
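The `mo` command belongs to the legacy OpenVINO Development Tools. Assuming it is not installed alongside a recent release such as the 2025.1 build reported here, the same ONNX-to-IR conversion can be sketched with the Python API (`ov.convert_model` and `ov.save_model`). The function names, the `_fp16` file suffix, and the paths below are hypothetical, and the thread does not confirm that this conversion fixes the iGPU output.

```python
from pathlib import Path


def ir_output_path(onnx_path: str, output_dir: str) -> Path:
    """Build the .xml path for the converted model (hypothetical naming scheme)."""
    return Path(output_dir) / (Path(onnx_path).stem + "_fp16.xml")


def convert_onnx_to_fp16_ir(onnx_path: str, output_dir: str) -> Path:
    """Convert an ONNX model to OpenVINO IR, storing FP32 weights as FP16."""
    # Imported lazily so the path helper above works without OpenVINO installed.
    import openvino as ov

    xml_path = ir_output_path(onnx_path, output_dir)
    xml_path.parent.mkdir(parents=True, exist_ok=True)
    model = ov.convert_model(onnx_path)
    # compress_to_fp16=True compresses FP32 constants to FP16 in the saved IR.
    ov.save_model(model, str(xml_path), compress_to_fp16=True)
    return xml_path
```

A call such as `convert_onnx_to_fp16_ir("/path/to/model.onnx", "/path/to/output_dir")` would write the .xml and .bin pair; the returned path can then be passed to `core.read_model` in the test script.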
OpenVINO Version
2025.1.0-18503-6fec06580ab-releases/2025/1
Operating System
Other (Please specify in description)
Device used for inference
GPU
Framework
ONNX
Model used
https://sparsezoo.neuralmagic.com/models/resnet_v1-50-imagenet-pruned85.4block_quantized?hardware=deepsparse-c6i.12xlarge&comparison=resnet_v1-50-imagenet-base
Issue description
OpenVINO version: 2025.1.0-18503-6fec06580ab-releases/2025/1
Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0]
Desktop: Intel Ultra 7 265K
OS: Ubuntu 24.04
[Linux NPU Driver v1.16.0]
OpenCL version: 25.13.33276.16
Please help me analyze this based on the models I provided—I need to use these quantized models in my project.
You can also find quantized models on SparseZoo (https://sparsezoo.neuralmagic.com/?ungrouped=true&modelSet=computer_vision&datasets=imagenet&tasks=classification), such as ResNet50 and MobileNet, and they all exhibit the same issue on the iGPU.
Thank you very much!
mobilenet-quantized version:
https://sparsezoo.neuralmagic.com/models/mobilenet_v1-1.0-imagenet-pruned.4block_quantized?hardware=deepsparse-c6i.12xlarge&comparison=mobilenet_v1-1.0-imagenet-base
ResNet50-quantized version:
https://sparsezoo.neuralmagic.com/models/resnet_v1-50-imagenet-pruned85.4block_quantized?hardware=deepsparse-c6i.12xlarge&comparison=resnet_v1-50-imagenet-base
Step-by-step reproduction
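import numpy as np
import cv2
import openvino as ov
import sys

def preprocess_image(image_path):
    ...  # body not shown in this report

def test_model_on_devices(model_path, test_images):
    ...  # body as in the full listing earlier in this report

def main():
    ...  # body as in the full listing earlier in this report

if __name__ == "__main__":
    main()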