onnx: produce tensor caps

- Add tensor description to srcpad caps

onnx: formatting

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/9172>
Authored by Daniel Morin, 2025-06-03 23:05:18 -04:00; committed by GStreamer Marge Bot
parent 9bd3a3be74
commit 28fafc5488
3 changed files with 104 additions and 4 deletions
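
For orientation before the diffs: each model output ends up described by a "tensor/strided" structure with "dims" and "type" fields, attached as a caps-valued field on the srcpad caps. The sketch below is illustrative only and not part of the commit; the output name "Gst.Model.ObjectDetector.Boxes", the 1x100x4 shape and the "float32" type string are assumptions (in the element they come from the model metadata, the ONNX output shape and gst_tensor_data_type_get_name()).

#include <gst/gst.h>

/* Illustrative sketch only: build one per-output tensor description the way
 * createSession() below does. The field name, shape and type string here are
 * hypothetical. `tensors` mirrors the "video/x-raw" structure the element
 * keeps in self->tensors. */
static void
build_example_tensor_caps (GstStructure * tensors)
{
  GstStructure *tensor_desc = gst_structure_new ("tensor/strided",
      "type", G_TYPE_STRING, "float32", NULL);
  GValue dims = G_VALUE_INIT, v = G_VALUE_INIT;
  const gint shape[] = { 1, 100, 4 };   /* hypothetical; dynamic dims become 0 */
  GstCaps *desc_caps;
  guint i;

  gst_value_array_init (&dims, G_N_ELEMENTS (shape));
  g_value_init (&v, G_TYPE_INT);
  for (i = 0; i < G_N_ELEMENTS (shape); i++) {
    g_value_set_int (&v, shape[i]);
    gst_value_array_append_value (&dims, &v);
  }
  gst_structure_take_value (tensor_desc, "dims", &dims);
  g_value_unset (&v);

  /* each model output becomes a caps-valued field on the srcpad caps;
   * gst_structure_set() takes its own reference to the caps */
  desc_caps = gst_caps_new_full (tensor_desc, NULL);
  gst_structure_set (tensors, "Gst.Model.ObjectDetector.Boxes",
      GST_TYPE_CAPS, desc_caps, NULL);
  gst_caps_unref (desc_caps);
}
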

gstonnxclient.cpp

@@ -21,6 +21,7 @@
*/
#include "gstonnxclient.h"
#include <onnxruntime_cxx_api.h>
#include <cpu_provider_factory.h>
#include <sstream>
@@ -51,6 +52,32 @@ namespace GstOnnxNamespace
return os;
}
const gint ONNX_TO_GST_TENSOR_DATATYPE[] = {
-1, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED*/
GST_TENSOR_DATA_TYPE_FLOAT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT*/
GST_TENSOR_DATA_TYPE_UINT8, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8*/
GST_TENSOR_DATA_TYPE_INT8, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8*/
GST_TENSOR_DATA_TYPE_UINT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16*/
GST_TENSOR_DATA_TYPE_INT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16*/
GST_TENSOR_DATA_TYPE_INT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32*/
GST_TENSOR_DATA_TYPE_INT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64*/
GST_TENSOR_DATA_TYPE_STRING, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING*/
GST_TENSOR_DATA_TYPE_BOOL, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL*/
GST_TENSOR_DATA_TYPE_FLOAT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16*/
GST_TENSOR_DATA_TYPE_FLOAT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE*/
GST_TENSOR_DATA_TYPE_UINT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32*/
GST_TENSOR_DATA_TYPE_UINT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64*/
GST_TENSOR_DATA_TYPE_COMPLEX64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64*/
GST_TENSOR_DATA_TYPE_COMPLEX128, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128*/
GST_TENSOR_DATA_TYPE_BFLOAT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16*/
GST_TENSOR_DATA_TYPE_FLOAT8E4M3FN, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN*/
GST_TENSOR_DATA_TYPE_FLOAT8E4M3FNUZ, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ*/
GST_TENSOR_DATA_TYPE_FLOAT8E5M2, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2*/
GST_TENSOR_DATA_TYPE_FLOAT8E5M2FNUZ, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ*/
GST_TENSOR_DATA_TYPE_UINT4, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4*/
GST_TENSOR_DATA_TYPE_INT4, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4*/
};
GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_parent),
session (nullptr),
width (0),
@@ -171,8 +198,32 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return session != nullptr;
}
bool GstOnnxClient::setTensorDescDatatype(ONNXTensorElementDataType dt,
GstStructure *tensor_desc)
{
GValue val = G_VALUE_INIT;
GstTensorDataType gst_dt;
g_value_init(&val, G_TYPE_STRING);
if (dt > ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED &&
dt <= ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4) {
gst_dt = (GstTensorDataType)ONNX_TO_GST_TENSOR_DATATYPE [dt];
g_value_set_string (&val, gst_tensor_data_type_get_name(gst_dt));
} else {
GST_ERROR_OBJECT (debug_parent, "Unexpected datatype: %d", dt);
g_value_unset (&val);
return false;
}
gst_structure_take_value(tensor_desc, "type", &val);
g_value_unset(&val);
return true;
}
bool GstOnnxClient::createSession (std::string modelFile,
-GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
+GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider,
+GstStructure * tensors)
{
if (session)
return true;
@@ -291,7 +342,14 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return false;
}
size_t i = 0;
for (auto & name:outputNamesRaw) {
Ort::TypeInfo ti = session->GetOutputTypeInfo(i++);
auto type_shape = ti.GetTensorTypeAndShapeInfo();
auto card = type_shape.GetDimensionsCount();
auto type = type_shape.GetElementType();
Ort::AllocatedStringPtr res =
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
if (res)
@@ -323,9 +381,34 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
name, GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS);
outputIds.push_back (quark);
} else {
-GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
+GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s",
+    name);
return false;
}
/* tensor description */
GstStructure *tensor_desc = gst_structure_new_empty("tensor/strided");
/* Setting dims */
GValue val_dims = G_VALUE_INIT, val = G_VALUE_INIT;
gst_value_array_init(&val_dims, card);
g_value_init(&val, G_TYPE_INT);
for (auto &dim : type_shape.GetShape()) {
g_value_set_int(&val, dim > 0 ? dim : 0);
gst_value_array_append_value(&val_dims, &val);
}
gst_structure_take_value(tensor_desc, "dims", &val_dims);
g_value_unset(&val_dims);
g_value_unset(&val);
/* Setting datatype */
if (!setTensorDescDatatype(type, tensor_desc))
return false;
/* Setting tensors caps */
gst_structure_set(tensors, res.get(), GST_TYPE_CAPS,
gst_caps_new_full(tensor_desc, NULL), NULL);
}
}
catch (Ort::Exception & ortex) {

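A side note on the gstonnxclient.cpp hunks above: ONNX Runtime reports dynamic dimensions as -1, which is why the loop clamps negative values to 0 before storing them in "dims". Below is a minimal standalone sketch of the same per-output inspection, assuming an already created Ort::Session; the printing is illustrative only.

#include <onnxruntime_cxx_api.h>
#include <iostream>
#include <vector>

/* Sketch: enumerate a session's outputs the way createSession() does,
 * reading element type and shape. Dynamic dimensions come back as -1. */
static void
dump_output_shapes (Ort::Session & session)
{
  for (size_t i = 0; i < session.GetOutputCount (); i++) {
    Ort::TypeInfo ti = session.GetOutputTypeInfo (i);
    auto info = ti.GetTensorTypeAndShapeInfo ();
    ONNXTensorElementDataType type = info.GetElementType ();
    std::vector < int64_t > shape = info.GetShape ();

    std::cout << "output " << i << " type " << type << " dims";
    for (int64_t d : shape)
      std::cout << " " << (d > 0 ? d : 0);      /* -1 (dynamic) -> 0 */
    std::cout << std::endl;
  }
}
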
gstonnxclient.h

@@ -64,7 +64,8 @@ namespace GstOnnxNamespace {
GstOnnxClient(GstElement *debug_parent);
~GstOnnxClient(void);
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
-GstOnnxExecutionProvider provider);
+GstOnnxExecutionProvider provider, GstStructure *
+tensors);
bool hasSession(void);
void setInputImageFormat(GstMlInputImageFormat format);
GstMlInputImageFormat getInputImageFormat(void);
@@ -90,6 +91,7 @@ namespace GstOnnxNamespace {
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
bool setTensorDescDatatype (ONNXTensorElementDataType dt, GstStructure * tensor_desc);
Ort::Env env;
Ort::Session * session;
int32_t width;

gstonnxinference.cpp

@@ -84,6 +84,7 @@ struct _GstOnnxInference
gpointer onnx_client;
gboolean onnx_disabled;
GstVideoInfo video_info;
GstStructure *tensors;
};
GST_DEBUG_CATEGORY (onnx_inference_debug);
@@ -345,6 +346,10 @@ gst_onnx_inference_init (GstOnnxInference * self)
{
self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
self->onnx_disabled = TRUE;
/* TODO: at the moment onnx inference only supports video output. We
* should revisit this once we generalize this aspect */
self->tensors = gst_structure_new_empty ("video/x-raw");
}
static void
@@ -353,6 +358,7 @@ gst_onnx_inference_finalize (GObject * object)
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
g_free (self->model_file);
gst_structure_free(self->tensors);
delete GST_ONNX_CLIENT_MEMBER (self);
G_OBJECT_CLASS (gst_onnx_inference_parent_class)->finalize (object);
}
@@ -456,7 +462,7 @@ gst_onnx_inference_create_session (GstBaseTransform * trans)
gboolean ret =
GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
self->optimization_level,
-self->execution_provider);
+self->execution_provider, self->tensors);
if (!ret) {
GST_ERROR_OBJECT (self,
"Unable to create ONNX session. Model is disabled.");
@@ -538,6 +544,15 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
GST_DEBUG_OBJECT(self, "Applying caps restrictions: %" GST_PTR_FORMAT,
restrictions);
if (direction == GST_PAD_SINK) {
GstCaps * tensors_caps = gst_caps_new_full (gst_structure_copy (
self->tensors), NULL);
GstCaps *intersect = gst_caps_intersect (restrictions, tensors_caps);
gst_caps_replace (&restrictions, intersect);
gst_caps_unref (tensors_caps);
gst_caps_unref (intersect);
}
other_caps = gst_caps_intersect_full (caps, restrictions,
GST_CAPS_INTERSECT_FIRST);
gst_caps_unref (restrictions);
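
Downstream of onnxinference, the description added by this change is visible as extra fields on the negotiated caps. Below is a rough sketch of how an application could read it back; the field name "Gst.Model.ObjectDetector.Boxes" is hypothetical and in practice comes from the model metadata.

#include <gst/gst.h>

/* Sketch: read the tensor description advertised on the onnxinference
 * srcpad caps after this change. */
static void
print_tensor_desc (GstPad * srcpad)
{
  GstCaps *caps = gst_pad_get_current_caps (srcpad);
  const GstStructure *s;
  const GValue *v;

  if (caps == NULL)
    return;

  s = gst_caps_get_structure (caps, 0);
  /* hypothetical output name; taken from the model metadata in practice */
  v = gst_structure_get_value (s, "Gst.Model.ObjectDetector.Boxes");
  if (v != NULL && GST_VALUE_HOLDS_CAPS (v)) {
    const GstCaps *tcaps = gst_value_get_caps (v);
    const GstStructure *desc = gst_caps_get_structure (tcaps, 0);
    const GValue *dims = gst_structure_get_value (desc, "dims");

    g_print ("type=%s rank=%u\n", gst_structure_get_string (desc, "type"),
        dims != NULL ? gst_value_array_get_size (dims) : 0);
  }
  gst_caps_unref (caps);
}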