Add Execution Provider OpenVINO

This commit is contained in:
Elias Rosendahl 2025-06-09 12:49:22 +02:00
parent 5186341c28
commit 75652a481a
4 changed files with 148 additions and 318 deletions

View File

@ -1,35 +1,4 @@
{
"name": "GStreamer Dev Toolbox",
"image": "registry.freedesktop.org/gstreamer/gstreamer/amd64/fedora:gst-toolbox-main",
"containerUser": "containeruser",
"remoteUser": "containeruser",
"postCreateCommand": ["python3", "${containerWorkspaceFolder}/ci/scripts/handle-subprojects-cache.py" ,"--cache-dir", "/var/cache/subprojects", "subprojects/"],
"privileged": false,
"capAdd": [ "SYS_PTRACE" ],
"customizations": {
"vscode": {
"settings": {
"files.watcherExclude": {
"**/target/**": true
},
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff"
}
},
"extensions": [
"charliermarsh.ruff",
"GitLab.gitlab-workflow",
"mesonbuild.mesonbuild",
"ms-python.mypy-type-checker",
"ms-python.pylint",
"ms-python.python",
"ms-vscode.cpptools",
"redhat.vscode-xml",
"redhat.vscode-yaml",
"rust-lang.rust-analyzer",
"tamasfe.even-better-toml",
"vadimcn.vscode-lldb"
]
}
}
"name": "FT-Driverless Dev",
"image": "git.fasttube.de/ft-driverless/ft_as:gstreamer-plugin-bad"
}

View File

@ -22,26 +22,17 @@
#include "gstonnxclient.h"
#include <onnxruntime_cxx_api.h>
#ifdef HAVE_VSI_NPU
#include <core/providers/vsinpu/vsinpu_provider_factory.h>
#endif
#ifdef CPUPROVIDER_IN_SUBDIR
#include <core/providers/cpu/cpu_provider_factory.h>
#else
#include <cpu_provider_factory.h>
#endif
#include <onnxruntime/core/providers/cpu/cpu_provider_factory.h>
#include <onnxruntime/core/providers/openvino/openvino_provider_factory.h>
#include <sstream>
#define GST_CAT_DEFAULT onnx_inference_debug
/* FIXME: to be replaced by ModelInfo files */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "ssd-mobilenet-v1-variant-1-out-boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "ssd-mobilenet-v1-variant-1-out-scores"
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "generic-variant-1-out-count"
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "ssd-mobilenet-v1-variant-1-out-classes"
/* FIXME: share this with tensordecoders, somehow? */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
namespace GstOnnxNamespace
{
@ -62,32 +53,6 @@ namespace GstOnnxNamespace
return os;
}
/* Lookup table mapping ONNXTensorElementDataType enum values (used as the
 * array index) to the matching GstTensorDataType.  Index 0
 * (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) has no GStreamer equivalent and
 * maps to -1; callers must range-check the index before using the result. */
const gint ONNX_TO_GST_TENSOR_DATATYPE[] = {
-1, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED*/
GST_TENSOR_DATA_TYPE_FLOAT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT*/
GST_TENSOR_DATA_TYPE_UINT8, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8*/
GST_TENSOR_DATA_TYPE_INT8, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8*/
GST_TENSOR_DATA_TYPE_UINT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16*/
GST_TENSOR_DATA_TYPE_INT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16*/
GST_TENSOR_DATA_TYPE_INT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32*/
GST_TENSOR_DATA_TYPE_INT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64*/
GST_TENSOR_DATA_TYPE_STRING, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING*/
GST_TENSOR_DATA_TYPE_BOOL, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL*/
GST_TENSOR_DATA_TYPE_FLOAT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16*/
GST_TENSOR_DATA_TYPE_FLOAT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE*/
GST_TENSOR_DATA_TYPE_UINT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32*/
GST_TENSOR_DATA_TYPE_UINT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64*/
GST_TENSOR_DATA_TYPE_COMPLEX64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64*/
GST_TENSOR_DATA_TYPE_COMPLEX128, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128*/
GST_TENSOR_DATA_TYPE_BFLOAT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16*/
GST_TENSOR_DATA_TYPE_FLOAT8E4M3FN, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN*/
GST_TENSOR_DATA_TYPE_FLOAT8E4M3FNUZ, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ*/
GST_TENSOR_DATA_TYPE_FLOAT8E5M2, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2*/
GST_TENSOR_DATA_TYPE_FLOAT8E5M2FNUZ, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ*/
GST_TENSOR_DATA_TYPE_UINT4, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4*/
GST_TENSOR_DATA_TYPE_INT4, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4*/
};
GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_parent),
session (nullptr),
width (0),
@ -100,8 +65,9 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
inputDatatypeSize (sizeof (uint8_t)),
fixedInputImageSize (false),
inputTensorOffset (0.0),
inputTensorScale (1.0)
{
inputTensorScale (1.0),
provider_config(nullptr)
{
}
GstOnnxClient::~GstOnnxClient () {
@ -109,6 +75,10 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
delete[]dest;
}
void GstOnnxClient::setProviderConfig (const char *config)
{
provider_config = config;
}
int32_t GstOnnxClient::getWidth (void)
{
return width;
@ -163,7 +133,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
break;
default:
g_error ("Data type %d not handled", inputDatatype);
break;
break;
};
}
@ -208,45 +178,14 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return session != nullptr;
}
/* Translate an ONNX tensor element type into the GStreamer tensor datatype
 * name and store it in the "type" field of @tensor_desc.
 *
 * @dt: ONNX element datatype reported by the session.
 * @tensor_desc: structure describing one output tensor; receives the
 *   "type" field on success.
 *
 * Returns: true on success, false (with an error logged) for datatypes
 * outside the supported ONNX range. */
bool GstOnnxClient::setTensorDescDatatype(ONNXTensorElementDataType dt,
    GstStructure *tensor_desc)
{
  GValue val = G_VALUE_INIT;
  GstTensorDataType gst_dt;

  g_value_init(&val, G_TYPE_STRING);

  if (dt > ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED &&
      dt <= ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4) {
    gst_dt = (GstTensorDataType)ONNX_TO_GST_TENSOR_DATATYPE [dt];
    g_value_set_string (&val, gst_tensor_data_type_get_name(gst_dt));
  } else {
    GST_ERROR_OBJECT (debug_parent, "Unexpected datatype: %d", dt);
    g_value_unset (&val);
    return false;
  }

  /* gst_structure_take_value() takes ownership of the GValue contents
   * without copying; calling g_value_unset() on val afterwards would free
   * the string the structure now owns, so val must NOT be unset here. */
  gst_structure_take_value(tensor_desc, "type", &val);

  return true;
}
bool GstOnnxClient::createSession (std::string modelFile,
GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider,
GstStructure * tensors)
GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
{
OrtStatus *status;
if (session)
return true;
try {
Ort::SessionOptions sessionOptions;
const auto & api = Ort::GetApi ();
// for debugging
//sessionOptions.SetIntraOpNumThreads (1);
GraphOptimizationLevel onnx_optim;
switch (optim) {
GraphOptimizationLevel onnx_optim;
switch (optim) {
case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL:
onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL;
break;
@ -262,66 +201,61 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
default:
onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
break;
};
};
try {
Ort::SessionOptions sessionOptions;
const auto & api = Ort::GetApi ();
// for debugging
//sessionOptions.SetIntraOpNumThreads (1);
sessionOptions.SetGraphOptimizationLevel (onnx_optim);
m_provider = provider;
switch (m_provider) {
case GST_ONNX_EXECUTION_PROVIDER_CUDA:
case GST_ONNX_EXECUTION_PROVIDER_CUDA:
try {
OrtCUDAProviderOptionsV2 *cuda_options = nullptr;
Ort::ThrowOnError (api.CreateCUDAProviderOptions (&cuda_options));
std::unique_ptr < OrtCUDAProviderOptionsV2,
decltype (api.ReleaseCUDAProviderOptions) >
rel_cuda_options (cuda_options, api.ReleaseCUDAProviderOptions);
rel_cuda_options (cuda_options, api.ReleaseCUDAProviderOptions);
Ort::ThrowOnError (api.SessionOptionsAppendExecutionProvider_CUDA_V2
(static_cast < OrtSessionOptions * >(sessionOptions),
rel_cuda_options.get ()));
} catch (Ort::Exception & ortex) {
GST_WARNING
("Failed to create CUDA provider - dropping back to CPU");
}
catch (Ort::Exception & ortex) {
GST_WARNING
("Failed to create CUDA provider - dropping back to CPU");
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
(sessionOptions, 1));
}
break;
case GST_ONNX_EXECUTION_PROVIDER_OPENVINO: {
std::unordered_map<std::string, std::string> ovOptions;
if (this->provider_config) {
std::istringstream ss(this->provider_config);
std::string kv;
while (std::getline(ss, kv, ',')) {
auto pos = kv.find('=');
if (pos == std::string::npos) continue;
ovOptions[kv.substr(0, pos)] = kv.substr(pos + 1);
}
}
sessionOptions.AppendExecutionProvider("OpenVINO", ovOptions);
}
break;
default:
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
(sessionOptions, 1));
}
break;
#ifdef HAVE_VSI_NPU
case GST_ONNX_EXECUTION_PROVIDER_VSI:
try {
status = OrtSessionOptionsAppendExecutionProvider_VSINPU(sessionOptions);
if (status != nullptr) {
GST_ERROR_OBJECT (debug_parent,
"Failed to set VSINPU AI execution provider: %s",
Ort::GetApi().GetErrorMessage(status));
return false;
}
}
catch (Ort::Exception & ortex) {
GST_ERROR_OBJECT (debug_parent,
"Failed to set VSINPU AI execution provider: %s", ortex.what ());
return false;
}
sessionOptions.DisableCpuMemArena();
break;
#endif
default:
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
(sessionOptions, 1));
break;
}
env = Ort::Env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
break;
};
env =
Ort::Env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
"GstOnnxNamespace");
env.DisableTelemetryEvents();
session = new Ort::Session (env, modelFile.c_str (), sessionOptions);
auto inputTypeInfo = session->GetInputTypeInfo (0);
std::vector < int64_t > inputDims =
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
height = inputDims[1];
width = inputDims[2];
@ -334,23 +268,23 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
fixedInputImageSize = width > 0 && height > 0;
GST_DEBUG_OBJECT (debug_parent, "Number of Output Nodes: %d",
(gint) session->GetOutputCount ());
(gint) session->GetOutputCount ());
ONNXTensorElementDataType elementType =
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetElementType ();
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetElementType ();
switch (elementType) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
setInputImageDatatype(GST_TENSOR_DATA_TYPE_UINT8);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
setInputImageDatatype(GST_TENSOR_DATA_TYPE_FLOAT32);
break;
default:
GST_ERROR_OBJECT (debug_parent,
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
setInputImageDatatype(GST_TENSOR_DATA_TYPE_UINT8);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
setInputImageDatatype(GST_TENSOR_DATA_TYPE_FLOAT32);
break;
default:
GST_ERROR_OBJECT (debug_parent,
"Only input tensors of type int8 and floatare supported");
return false;
}
return false;
}
Ort::AllocatorWithDefaultOptions allocator;
auto input_name = session->GetInputNameAllocated (0, allocator);
@ -358,8 +292,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
for (size_t i = 0; i < session->GetOutputCount (); ++i) {
auto output_name = session->GetOutputNameAllocated (i, allocator);
GST_DEBUG_OBJECT (debug_parent, "Output name %lu:%s", i,
output_name.get ());
GST_DEBUG_OBJECT (debug_parent, "Output name %lu:%s", i, output_name.get ());
outputNames.push_back (std::move (output_name));
}
genOutputNamesRaw ();
@ -367,8 +300,8 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
// look up tensor ids
auto metaData = session->GetModelMetadata ();
OrtAllocator *ortAllocator;
status =
Ort::GetApi ().GetAllocatorWithDefaultOptions (&ortAllocator);
auto status =
Ort::GetApi ().GetAllocatorWithDefaultOptions (&ortAllocator);
if (status) {
// Handle the error case
const char *errorString = Ort::GetApi ().GetErrorMessage (status);
@ -379,26 +312,20 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return false;
}
size_t i = 0;
for (auto & name:outputNamesRaw) {
Ort::TypeInfo ti = session->GetOutputTypeInfo(i++);
auto type_shape = ti.GetTensorTypeAndShapeInfo();
auto card = type_shape.GetDimensionsCount();
auto type = type_shape.GetElementType();
Ort::AllocatedStringPtr res =
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
if (res) {
if (res)
{
GQuark quark = g_quark_from_string (res.get ());
outputIds.push_back (quark);
} else if (g_str_has_prefix (name, "scores")) {
} else if (g_str_has_prefix (name, "detection_scores")) {
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES);
GST_INFO_OBJECT(debug_parent,
"No custom metadata for key '%s', assuming %s",
name, GST_MODEL_OBJECT_DETECTOR_SCORES);
outputIds.push_back (quark);
} else if (g_str_has_prefix(name, "boxes")) {
} else if (g_str_has_prefix(name, "detection_boxes")) {
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES);
GST_INFO_OBJECT(debug_parent,
"No custom metadata for key '%s', assuming %s",
@ -420,53 +347,16 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
return false;
}
GST_DEBUG_OBJECT (debug_parent, "Tensor %zu (%s) has id \"%s\"", i, name,
g_quark_to_string (outputIds.back ()));
/* tensor description */
GstStructure *tensor_desc = gst_structure_new_empty("tensor/strided");
/* Setting dims */
GValue val_dims = G_VALUE_INIT, val = G_VALUE_INIT;
gst_value_array_init(&val_dims, card);
g_value_init(&val, G_TYPE_INT);
for (auto &dim : type_shape.GetShape()) {
g_value_set_int(&val, dim > 0 ? dim : 0);
gst_value_array_append_value(&val_dims, &val);
}
gst_structure_take_value(tensor_desc, "dims", &val_dims);
g_value_unset(&val_dims);
g_value_unset(&val);
/* Setting datatype */
if (!setTensorDescDatatype(type, tensor_desc))
return false;
/* Setting tensors caps */
gst_structure_set(tensors, res.get(), GST_TYPE_CAPS,
gst_caps_new_full(tensor_desc, NULL), NULL);
}
} catch (Ort::Exception & ortex) {
}
catch (Ort::Exception & ortex) {
GST_ERROR_OBJECT (debug_parent, "%s", ortex.what ());
return false;
}
return true;
}
void GstOnnxClient::destroySession (void)
{
if (session == NULL)
return;
delete session;
session = NULL;
}
void GstOnnxClient::parseDimensions (GstVideoInfo vinfo)
{
int32_t newWidth = fixedInputImageSize ? width : vinfo.width;
@ -575,7 +465,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
std::ostringstream buffer;
buffer << inputDims;
GST_LOG_OBJECT (debug_parent, "Input dimensions: %s", buffer.str ().c_str ());
GST_DEBUG_OBJECT (debug_parent, "Input dimensions: %s", buffer.str ().c_str ());
// copy video frame
uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };

View File

@ -54,7 +54,7 @@ typedef enum
{
GST_ONNX_EXECUTION_PROVIDER_CPU,
GST_ONNX_EXECUTION_PROVIDER_CUDA,
GST_ONNX_EXECUTION_PROVIDER_VSI,
GST_ONNX_EXECUTION_PROVIDER_OPENVINO,
} GstOnnxExecutionProvider;
@ -65,10 +65,8 @@ namespace GstOnnxNamespace {
GstOnnxClient(GstElement *debug_parent);
~GstOnnxClient(void);
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
GstOnnxExecutionProvider provider, GstStructure *
tensors);
GstOnnxExecutionProvider provider);
bool hasSession(void);
void destroySession(void);
void setInputImageFormat(GstMlInputImageFormat format);
GstMlInputImageFormat getInputImageFormat(void);
GstTensorDataType getInputImageDatatype(void);
@ -85,6 +83,7 @@ namespace GstOnnxNamespace {
GstTensorMeta *copy_tensors_to_meta (std::vector<Ort::Value> &outputs,
GstBuffer *buffer);
void parseDimensions(GstVideoInfo vinfo);
void setProviderConfig(const char *config);
private:
GstElement *debug_parent;
@ -93,7 +92,6 @@ namespace GstOnnxNamespace {
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
bool setTensorDescDatatype (ONNXTensorElementDataType dt, GstStructure * tensor_desc);
Ort::Env env;
Ort::Session * session;
int32_t width;
@ -112,6 +110,7 @@ namespace GstOnnxNamespace {
bool fixedInputImageSize;
float inputTensorOffset;
float inputTensorScale;
const char *provider_config;
};
}

View File

@ -72,6 +72,7 @@
* @optimization_level: ONNX session optimization level
* @execution_provider: ONNX execution provider
* @onnx_client opaque pointer to ONNX client
* @onnx_disabled true if inference is disabled
* @video_info @ref GstVideoInfo of sink caps
*/
struct _GstOnnxInference
@ -81,8 +82,9 @@ struct _GstOnnxInference
GstOnnxOptimizationLevel optimization_level;
GstOnnxExecutionProvider execution_provider;
gpointer onnx_client;
gboolean onnx_disabled;
GstVideoInfo video_info;
GstStructure *tensors;
gchar *provider_config;
};
GST_DEBUG_CATEGORY (onnx_inference_debug);
@ -100,6 +102,7 @@ enum
PROP_INPUT_IMAGE_FORMAT,
PROP_OPTIMIZATION_LEVEL,
PROP_EXECUTION_PROVIDER,
PROP_PROVIDER_CONFIG,
PROP_INPUT_OFFSET,
PROP_INPUT_SCALE
};
@ -130,13 +133,12 @@ static GstFlowReturn gst_onnx_inference_transform_ip (GstBaseTransform *
trans, GstBuffer * buf);
static gboolean gst_onnx_inference_process (GstBaseTransform * trans,
GstBuffer * buf);
static gboolean gst_onnx_inference_create_session (GstBaseTransform * trans);
static GstCaps *gst_onnx_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
static gboolean
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
static gboolean gst_onnx_inference_start (GstBaseTransform * trans);
static gboolean gst_onnx_inference_stop (GstBaseTransform * trans);
G_DEFINE_TYPE (GstOnnxInference, gst_onnx_inference, GST_TYPE_BASE_TRANSFORM);
@ -187,24 +189,12 @@ gst_onnx_execution_provider_get_type (void)
static GEnumValue execution_provider_types[] = {
{GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider",
"cpu"},
#if HAVE_CUDA
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
"CUDA execution provider",
"cuda"},
#else
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
"CUDA execution provider (compiled out, will use CPU)",
"cuda"},
#endif
#ifdef HAVE_VSI_NPU
{GST_ONNX_EXECUTION_PROVIDER_VSI,
"VeriSilicon NPU execution provider",
"vsi"},
#else
{GST_ONNX_EXECUTION_PROVIDER_VSI,
"VeriSilicon NPU execution provider (compiled out, will use CPU)",
"vsi"},
#endif
{GST_ONNX_EXECUTION_PROVIDER_OPENVINO,
"OPENVINO execution provider",
"openvino"},
{0, NULL, NULL},
};
@ -331,6 +321,14 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
G_MINFLOAT, G_MAXFLOAT, 1.0,
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_PROVIDER_CONFIG,
g_param_spec_string ("provider-config",
"Provider config",
"Comma-separierte Key=Value-Optionen",
nullptr,
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "onnxinference",
"Filter/Effect/Video",
@ -346,10 +344,6 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_caps);
basetransform_class->set_caps =
GST_DEBUG_FUNCPTR (gst_onnx_inference_set_caps);
basetransform_class->start =
GST_DEBUG_FUNCPTR(gst_onnx_inference_start);
basetransform_class->stop =
GST_DEBUG_FUNCPTR(gst_onnx_inference_stop);
gst_type_mark_as_plugin_api (GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
(GstPluginAPIFlags) 0);
@ -363,18 +357,16 @@ static void
gst_onnx_inference_init (GstOnnxInference * self)
{
self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
/* TODO: at the moment onnx inference only support video output. We
* should revisit this once we generalize this aspect */
self->tensors = gst_structure_new_empty ("video/x-raw");
self->onnx_disabled = TRUE;
}
static void
gst_onnx_inference_finalize (GObject * object)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
g_free (self->provider_config);
self->provider_config = NULL;
g_free (self->model_file);
gst_structure_free(self->tensors);
delete GST_ONNX_CLIENT_MEMBER (self);
G_OBJECT_CLASS (gst_onnx_inference_parent_class)->finalize (object);
}
@ -396,6 +388,7 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
if (self->model_file)
g_free (self->model_file);
self->model_file = g_strdup (filename);
self->onnx_disabled = FALSE;
} else {
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
}
@ -418,6 +411,11 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
case PROP_INPUT_SCALE:
onnxClient->setInputImageScale (g_value_get_float (value));
break;
case PROP_PROVIDER_CONFIG:
g_free (self->provider_config);
self->provider_config = g_value_dup_string (value);
onnxClient->setProviderConfig(self->provider_config);
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
@ -456,6 +454,45 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
}
}
/* Lazily create the ONNX session for @trans from the current properties.
 *
 * Returns FALSE only when inference has already been disabled; otherwise
 * returns TRUE, switching the element to passthrough (and disabling
 * inference) when the session cannot be created or no model file is set.
 * Serialized with the object lock against concurrent property access. */
static gboolean
gst_onnx_inference_create_session (GstBaseTransform * trans)
{
  GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
  auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
  gboolean missing_model = FALSE;

  GST_OBJECT_LOCK (self);
  if (self->onnx_disabled) {
    GST_OBJECT_UNLOCK (self);
    return FALSE;
  }

  if (onnxClient->hasSession ()) {
    GST_OBJECT_UNLOCK (self);
    return TRUE;
  }

  if (self->model_file) {
    gboolean ret = onnxClient->createSession (self->model_file,
        self->optimization_level,
        self->execution_provider);

    if (!ret) {
      GST_ERROR_OBJECT (self,
          "Unable to create ONNX session. Model is disabled.");
      self->onnx_disabled = TRUE;
    }
  } else {
    self->onnx_disabled = TRUE;
    missing_model = TRUE;
  }
  GST_OBJECT_UNLOCK (self);

  /* Post the element error outside the object lock: message delivery can
   * re-enter the element and take the lock, so posting while locked risks
   * deadlock. */
  if (missing_model)
    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Model file not found"));

  if (self->onnx_disabled) {
    gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
  }

  return TRUE;
}
static GstCaps *
gst_onnx_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
@ -464,17 +501,9 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
GstCaps *other_caps;
GstCaps *restrictions;
bool has_session;
GST_OBJECT_LOCK (self);
has_session = onnxClient->hasSession ();
GST_OBJECT_UNLOCK (self);
if (!has_session) {
other_caps = gst_caps_ref (caps);
goto done;
}
if (!gst_onnx_inference_create_session (trans))
return NULL;
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
if (gst_base_transform_is_passthrough (trans))
@ -528,20 +557,10 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
GST_DEBUG_OBJECT(self, "Applying caps restrictions: %" GST_PTR_FORMAT,
restrictions);
if (direction == GST_PAD_SINK) {
GstCaps * tensors_caps = gst_caps_new_full (gst_structure_copy (
self->tensors), NULL);
GstCaps *intersect = gst_caps_intersect (restrictions, tensors_caps);
gst_caps_replace (&restrictions, intersect);
gst_caps_unref (tensors_caps);
gst_caps_unref (intersect);
}
other_caps = gst_caps_intersect_full (caps, restrictions,
GST_CAPS_INTERSECT_FIRST);
gst_caps_unref (restrictions);
done:
if (filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (
other_caps, filter_caps, GST_CAPS_INTERSECT_FIRST);
@ -552,53 +571,6 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
return other_caps;
}
/* Element start: create the ONNX session up front so later negotiation and
 * processing can rely on it.  Posts an element error when model-file is
 * unset; returns FALSE when no session could be created. */
static gboolean
gst_onnx_inference_start (GstBaseTransform * trans)
{
  GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
  auto client = GST_ONNX_CLIENT_MEMBER (self);
  gboolean ok = FALSE;

  GST_OBJECT_LOCK (self);
  if (client->hasSession ()) {
    ok = TRUE;
  } else if (self->model_file == NULL) {
    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
        ("model-file property not set"));
  } else {
    ok = client->createSession (self->model_file,
        self->optimization_level,
        self->execution_provider,
        self->tensors);
    if (!ok)
      GST_ERROR_OBJECT (self,
          "Unable to create ONNX session. Model is disabled.");
  }
  GST_OBJECT_UNLOCK (self);

  return ok;
}
/* Element stop: drop the ONNX session (if any) under the object lock so a
 * subsequent start rebuilds it from the current properties. */
static gboolean
gst_onnx_inference_stop (GstBaseTransform * trans)
{
  GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
  auto client = GST_ONNX_CLIENT_MEMBER (self);

  GST_OBJECT_LOCK (self);
  if (client->hasSession ())
    client->destroySession ();
  GST_OBJECT_UNLOCK (self);

  return TRUE;
}
static gboolean
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps)