From 13de5160be26abed1c0237634a1e69f9bbe9e7df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Cr=C3=AAte?= Date: Wed, 24 Jan 2024 21:31:44 -0500 Subject: [PATCH] onnx: Add more tensor data types Part-of: --- .../ext/onnx/gstonnxclient.cpp | 18 ++++++++----- .../ext/onnx/gstonnxinference.cpp | 2 +- .../ext/onnx/tensor/gsttensormeta.h | 26 ++++++++++++++++--- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index 98f21a66fe..37060dd543 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -53,7 +53,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren dest (nullptr), m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU), inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC), - inputDatatype (GST_TENSOR_TYPE_INT8), + inputDatatype (GST_TENSOR_TYPE_UINT8), inputDatatypeSize (sizeof (uint8_t)), fixedInputImageSize (false), inputTensorOffset (0.0), @@ -100,21 +100,27 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren { inputDatatype = datatype; switch (inputDatatype) { - case GST_TENSOR_TYPE_INT8: + case GST_TENSOR_TYPE_UINT8: inputDatatypeSize = sizeof (uint8_t); break; - case GST_TENSOR_TYPE_INT16: + case GST_TENSOR_TYPE_UINT16: inputDatatypeSize = sizeof (uint16_t); break; - case GST_TENSOR_TYPE_INT32: + case GST_TENSOR_TYPE_UINT32: inputDatatypeSize = sizeof (uint32_t); break; + case GST_TENSOR_TYPE_INT32: + inputDatatypeSize = sizeof (int32_t); + break; case GST_TENSOR_TYPE_FLOAT16: inputDatatypeSize = 2; break; case GST_TENSOR_TYPE_FLOAT32: inputDatatypeSize = sizeof (float); break; + default: + g_error ("Data type %d not handled", inputDatatype); + break; }; } @@ -241,7 +247,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren switch (elementType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - setInputImageDatatype(GST_TENSOR_TYPE_INT8); + setInputImageDatatype(GST_TENSOR_TYPE_UINT8); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32); @@ -450,7 +456,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren std::vector < Ort::Value > inputTensors; switch (inputDatatype) { - case GST_TENSOR_TYPE_INT8: + case GST_TENSOR_TYPE_UINT8: uint8_t *src_data; if (inputTensorOffset == 00 && inputTensorScale == 1.0) { src_data = img_data; diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp index 21eaade85a..00b37e6e45 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp @@ -489,7 +489,7 @@ gst_onnx_inference_transform_caps (GstBaseTransform * onnxClient->getWidth (), "height", G_TYPE_INT, onnxClient->getHeight (), NULL); - if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_INT8 && + if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_UINT8 && onnxClient->getInputImageScale() == 1.0 && onnxClient->getInputImageOffset() == 0.0) { switch (onnxClient->getChannels()) { diff --git a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h index b80f98dda2..49983308f4 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h +++ b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h @@ -27,21 +27,39 @@ /** * GstTensorType: * - * @GST_TENSOR_TYPE_INT8 8 bit integer tensor data - * @GST_TENSOR_TYPE_INT16 16 bit integer tensor data - * @GST_TENSOR_TYPE_INT32 32 bit integer tensor data + * @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data + * @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data + * @GST_TENSOR_TYPE_INT16 signed 16 bit integer tensor data + * @GST_TENSOR_TYPE_INT32 signed 32 bit integer tensor data + * @GST_TENSOR_TYPE_INT64 signed 64 bit integer tensor data + * @GST_TENSOR_TYPE_UINT4 unsigned 4 bit integer tensor data + * @GST_TENSOR_TYPE_UINT8 unsigned 8 bit integer tensor data + * @GST_TENSOR_TYPE_UINT16 unsigned 16 bit integer tensor data + * @GST_TENSOR_TYPE_UINT32 unsigned 32 bit integer tensor data + * @GST_TENSOR_TYPE_UINT64 unsigned 64 bit integer tensor data * @GST_TENSOR_TYPE_FLOAT16 16 bit floating point tensor data * @GST_TENSOR_TYPE_FLOAT32 32 bit floating point tensor data + * @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data + * @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data * * Since: 1.24 */ typedef enum _GstTensorType { + GST_TENSOR_TYPE_INT4, GST_TENSOR_TYPE_INT8, GST_TENSOR_TYPE_INT16, GST_TENSOR_TYPE_INT32, + GST_TENSOR_TYPE_INT64, + GST_TENSOR_TYPE_UINT4, + GST_TENSOR_TYPE_UINT8, + GST_TENSOR_TYPE_UINT16, + GST_TENSOR_TYPE_UINT32, + GST_TENSOR_TYPE_UINT64, GST_TENSOR_TYPE_FLOAT16, - GST_TENSOR_TYPE_FLOAT32 + GST_TENSOR_TYPE_FLOAT32, + GST_TENSOR_TYPE_FLOAT64, + GST_TENSOR_TYPE_BFLOAT16, } GstTensorType;