#! /usr/bin/python3
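# Export the YOLO-NAS-Pose models (N/S/M/L) from PyTorch to ONNX at FP32, FP16 and INT8
# precision, optionally benchmark each exported model with onnxruntime on a sample image,
# and sanity-check the decoded detections both numerically and visually.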
from termcolor import cprint, colored
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.conversion import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode
import time
import cv2
import numpy as np
from super_gradients.training.utils.media.image import load_image
import onnxruntime
import os
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
import matplotlib.pyplot as plt
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
# Disable the super-gradients crash handler
os.environ['CRASH_HANDLER'] = '0'
# Conversion Settings
CONVERSION = True
input_image_shape = [640, 640]
quantization_modes = [ExportQuantizationMode.INT8, ExportQuantizationMode.FP16, None]
output_predictions_format = DetectionOutputFormatMode.FLAT_FORMAT
# NMS-related Settings
confidence_threshold = .15
nms_threshold = .2
num_pre_nms_predictions = 1000
max_predictions_per_image = 10
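# confidence_threshold discards low-score detections before NMS, nms_threshold is the IoU
# cutoff used by NMS, num_pre_nms_predictions caps the candidates passed into NMS, and
# max_predictions_per_image bounds how many detections the exported model returns per image.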
# ONNXruntime Benchmark Settings
BENCHMARK = True
n_run = 1000
n_warm_up = 200
image_name = "https://deci-pretrained-models.s3.amazonaws.com/sample_images/beatles-abbeyroad.jpg"
# Checks
SHAPE_CHECK = True
VISUAL_CHECK = True
CALIBRATION_DATASET_CHECK = False
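# SHAPE_CHECK prints the decoded predictions for the sample image, VISUAL_CHECK renders
# them with PoseVisualization, and CALIBRATION_DATASET_CHECK displays each calibration
# image as it is preprocessed.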
# Function to convert tensor to image for visualization
def tensor_to_image(tensor):
# Convert the tensor to a numpy array
numpy_image = tensor.numpy()
# The output of ToTensor() is in C x H x W format, convert to H x W x C
numpy_image = numpy_image.transpose(1, 2, 0)
# Undo the normalization (if any)
# numpy_image = numpy_image * std + mean # Adjust based on your normalization
return numpy_image
class HFDatasetWrapper(Dataset):
def __init__(self, hf_dataset, transform=None):
self.hf_dataset = hf_dataset
self.transform = transform
def __len__(self):
return len(self.hf_dataset)
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.transform:
item = self.transform(item)
return item['image']
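# The calibration DataLoader only needs model inputs, so the wrapper above returns just the
# preprocessed image tensor produced by preprocess() below.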
def preprocess(data):
    # Get the decoded PIL Image from the dataset record
    image = data['image']
# Convert to RGB if not already
if image.mode != 'RGB':
image = image.convert('RGB')
# Define your transformations
transform = transforms.Compose([
transforms.Resize((640, 640)), # Resize (example size)
transforms.ToTensor(), # Convert to tensor
# Add normalization or other transformations if needed
])
# Process Image
transformed = transform(image)
if CALIBRATION_DATASET_CHECK:
# Display the Processed Image
plt_image = tensor_to_image(transformed)
plt.imshow(plt_image)
plt.axis('off') # Turn off axis numbers
plt.show()
return {'image': transformed}
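# With FLAT_FORMAT the model emits a single [N, 57] array: each row is
# [image_index, x1, y1, x2, y2, confidence, (x, y, score) for each of the 17 COCO keypoints].
# The two helpers below decode and visualize that layout.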
def iterate_over_flat_predictions(predictions, batch_size):
[flat_predictions] = predictions
for image_index in range(batch_size):
mask = flat_predictions[:, 0] == image_index
pred_bboxes = flat_predictions[mask, 1:5]
pred_scores = flat_predictions[mask, 5]
pred_joints = flat_predictions[mask, 6:].reshape((len(pred_bboxes), -1, 3))
yield image_index, pred_bboxes, pred_scores, pred_joints
def show_predictions_from_flat_format(image, predictions):
image_index, pred_boxes, pred_scores, pred_joints = next(iter(iterate_over_flat_predictions(predictions, 1)))
image = PoseVisualization.draw_poses(
image=image, poses=pred_joints, scores=pred_scores, boxes=pred_boxes,
edge_links=None, edge_colors=None, keypoint_colors=None, is_crowd=None
)
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.tight_layout()
plt.show()
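# Load the sample image, resize it to the export resolution, and convert it from
# H x W x C (uint8) to the 1 x C x H x W layout the exported ONNX model expects.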
image = load_image(image_name)
image = cv2.resize(image, (input_image_shape[1], input_image_shape[0]))
image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))
# Prepare Calibration Dataset for INT8 Quantization
dataset = load_dataset("cppe-5", split="train")
hf_dataset_wrapper = HFDatasetWrapper(dataset, transform=preprocess)
calibration_loader = DataLoader(hf_dataset_wrapper, batch_size=8)
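# Export every YOLO-NAS-Pose variant once per quantization mode, write the usage notes
# returned by export(), then optionally benchmark and check the resulting ONNX model.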
for model_name in [Models.YOLO_NAS_POSE_L, Models.YOLO_NAS_POSE_M, Models.YOLO_NAS_POSE_N, Models.YOLO_NAS_POSE_S]:
for q in quantization_modes:
# Specify Quantization Mode in Exported ONNX Model Name
        if q is None:
q_label = 'fp32'
elif q == ExportQuantizationMode.INT8:
q_label = 'int8'
elif q == ExportQuantizationMode.FP16:
q_label = 'fp16'
        else:
            raise ValueError(f"Unsupported quantization mode: {q}")
export_name = f"{model_name}_{q_label}.onnx"
        # Convert the model from PyTorch to ONNX using the official Super-Gradients export method
print(f"1. Convert {colored(model_name,'blue')} from PyTorch to ONNX format using {colored(q_label,'red')} precision, saved as {colored(export_name,'green')}")
if CONVERSION:
model = models.get(model_name, pretrained_weights="coco_pose")
export_result = model.export(
output=export_name,
confidence_threshold=confidence_threshold,
nms_threshold=nms_threshold,
engine=ExportTargetBackend.ONNXRUNTIME,
quantization_mode=q,
#selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
calibration_loader = calibration_loader if q == ExportQuantizationMode.INT8 else None,
#calibration_method: str = "percentile",
#calibration_batches: int = 16,
#calibration_percentile: float = 99.99,
preprocessing=True,
postprocessing=True,
#postprocessing_kwargs: Optional[dict] = None,
batch_size=1,
input_image_shape=input_image_shape,
#input_image_channels: Optional[int] = None,
#input_image_dtype: Optional[torch.dtype] = None,
max_predictions_per_image=max_predictions_per_image,
onnx_export_kwargs={"opset_version":14},
onnx_simplify=True,
#device: Optional[Union[torch.device, str]] = None,
output_predictions_format=output_predictions_format,
num_pre_nms_predictions=num_pre_nms_predictions,
)
            # Also save the export result (usage instructions) to a text file
usage_name = export_name + '.usage.txt'
with open(usage_name, 'w') as f:
f.write(str(export_result))
print(f"1.1 Related usage to {colored(export_name, 'green')} has been stored to {colored(usage_name,'yellow')}")
if BENCHMARK:
# Perform Inference on ONNXruntime
session = onnxruntime.InferenceSession(export_name, providers=['CUDAExecutionProvider',"CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
            # Warm-up runs, then timed runs to estimate average latency and FPS
            for i in range(n_warm_up): result = session.run(outputs, {inputs[0]: image_bchw})
            t = time.time()
            for i in range(n_run): result = session.run(outputs, {inputs[0]: image_bchw})
            latency = (time.time() - t) / n_run
            fps = round(1 / latency, 2)
            print(f'2. Averaged FPS: {colored(fps, "red")}')
            if SHAPE_CHECK:
                # Detection Result Shape Check
                for image_index, pred_bboxes, pred_scores, pred_joints in iterate_over_flat_predictions(result, batch_size=1):
                    N = pred_scores.shape[0]
                    for i in range(N):
                        print(f'Detected Object {colored(i,"green")}')
                        print(f'Predicted Bounding Box (Dimension: 1 x 4)', pred_bboxes[i,:])
                        print(f'Pose Confidence (scalar)', pred_scores[i])
                        print(f'Predicted Joints (Dimension: 17 x 3)', pred_joints[i,:,:])
            if VISUAL_CHECK:
                # Detection Result Visual Check
                show_predictions_from_flat_format(image, result)