Add Execution Provider OpenVINO
This commit is contained in:
parent
5186341c28
commit
75652a481a
@ -1,35 +1,4 @@
|
||||
{
|
||||
"name": "GStreamer Dev Toolbox",
|
||||
"image": "registry.freedesktop.org/gstreamer/gstreamer/amd64/fedora:gst-toolbox-main",
|
||||
"containerUser": "containeruser",
|
||||
"remoteUser": "containeruser",
|
||||
"postCreateCommand": ["python3", "${containerWorkspaceFolder}/ci/scripts/handle-subprojects-cache.py" ,"--cache-dir", "/var/cache/subprojects", "subprojects/"],
|
||||
"privileged": false,
|
||||
"capAdd": [ "SYS_PTRACE" ],
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"settings": {
|
||||
"files.watcherExclude": {
|
||||
"**/target/**": true
|
||||
},
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "charliermarsh.ruff"
|
||||
}
|
||||
},
|
||||
"extensions": [
|
||||
"charliermarsh.ruff",
|
||||
"GitLab.gitlab-workflow",
|
||||
"mesonbuild.mesonbuild",
|
||||
"ms-python.mypy-type-checker",
|
||||
"ms-python.pylint",
|
||||
"ms-python.python",
|
||||
"ms-vscode.cpptools",
|
||||
"redhat.vscode-xml",
|
||||
"redhat.vscode-yaml",
|
||||
"rust-lang.rust-analyzer",
|
||||
"tamasfe.even-better-toml",
|
||||
"vadimcn.vscode-lldb"
|
||||
]
|
||||
}
|
||||
}
|
||||
"name": "FT-Driverless Dev",
|
||||
"image": "git.fasttube.de/ft-driverless/ft_as:gstreamer-plugin-bad"
|
||||
}
|
||||
|
@ -22,26 +22,17 @@
|
||||
|
||||
#include "gstonnxclient.h"
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
#ifdef HAVE_VSI_NPU
|
||||
#include <core/providers/vsinpu/vsinpu_provider_factory.h>
|
||||
#endif
|
||||
|
||||
#ifdef CPUPROVIDER_IN_SUBDIR
|
||||
#include <core/providers/cpu/cpu_provider_factory.h>
|
||||
#else
|
||||
#include <cpu_provider_factory.h>
|
||||
#endif
|
||||
|
||||
#include <onnxruntime/core/providers/cpu/cpu_provider_factory.h>
|
||||
#include <onnxruntime/core/providers/openvino/openvino_provider_factory.h>
|
||||
#include <sstream>
|
||||
|
||||
#define GST_CAT_DEFAULT onnx_inference_debug
|
||||
|
||||
/* FIXME: to be replaced by ModelInfo files */
|
||||
#define GST_MODEL_OBJECT_DETECTOR_BOXES "ssd-mobilenet-v1-variant-1-out-boxes"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_SCORES "ssd-mobilenet-v1-variant-1-out-scores"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "generic-variant-1-out-count"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "ssd-mobilenet-v1-variant-1-out-classes"
|
||||
/* FIXME: share this with tensordecoders, somehow? */
|
||||
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
|
||||
|
||||
namespace GstOnnxNamespace
|
||||
{
|
||||
@ -62,32 +53,6 @@ namespace GstOnnxNamespace
|
||||
return os;
|
||||
}
|
||||
|
||||
const gint ONNX_TO_GST_TENSOR_DATATYPE[] = {
|
||||
-1, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED*/
|
||||
GST_TENSOR_DATA_TYPE_FLOAT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT*/
|
||||
GST_TENSOR_DATA_TYPE_UINT8, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8*/
|
||||
GST_TENSOR_DATA_TYPE_INT8, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8*/
|
||||
GST_TENSOR_DATA_TYPE_UINT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16*/
|
||||
GST_TENSOR_DATA_TYPE_INT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16*/
|
||||
GST_TENSOR_DATA_TYPE_INT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32*/
|
||||
GST_TENSOR_DATA_TYPE_INT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64*/
|
||||
GST_TENSOR_DATA_TYPE_STRING, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING*/
|
||||
GST_TENSOR_DATA_TYPE_BOOL, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL*/
|
||||
GST_TENSOR_DATA_TYPE_FLOAT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16*/
|
||||
GST_TENSOR_DATA_TYPE_FLOAT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE*/
|
||||
GST_TENSOR_DATA_TYPE_UINT32, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32*/
|
||||
GST_TENSOR_DATA_TYPE_UINT64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64*/
|
||||
GST_TENSOR_DATA_TYPE_COMPLEX64, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64*/
|
||||
GST_TENSOR_DATA_TYPE_COMPLEX128, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128*/
|
||||
GST_TENSOR_DATA_TYPE_BFLOAT16, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16*/
|
||||
GST_TENSOR_DATA_TYPE_FLOAT8E4M3FN, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN*/
|
||||
GST_TENSOR_DATA_TYPE_FLOAT8E4M3FNUZ, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ*/
|
||||
GST_TENSOR_DATA_TYPE_FLOAT8E5M2, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2*/
|
||||
GST_TENSOR_DATA_TYPE_FLOAT8E5M2FNUZ, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ*/
|
||||
GST_TENSOR_DATA_TYPE_UINT4, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4*/
|
||||
GST_TENSOR_DATA_TYPE_INT4, /* ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4*/
|
||||
};
|
||||
|
||||
GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_parent),
|
||||
session (nullptr),
|
||||
width (0),
|
||||
@ -100,8 +65,9 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
inputDatatypeSize (sizeof (uint8_t)),
|
||||
fixedInputImageSize (false),
|
||||
inputTensorOffset (0.0),
|
||||
inputTensorScale (1.0)
|
||||
{
|
||||
inputTensorScale (1.0),
|
||||
provider_config(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
GstOnnxClient::~GstOnnxClient () {
|
||||
@ -109,6 +75,10 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
delete[]dest;
|
||||
}
|
||||
|
||||
void GstOnnxClient::setProviderConfig (const char *config)
|
||||
{
|
||||
provider_config = config;
|
||||
}
|
||||
int32_t GstOnnxClient::getWidth (void)
|
||||
{
|
||||
return width;
|
||||
@ -163,7 +133,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
break;
|
||||
default:
|
||||
g_error ("Data type %d not handled", inputDatatype);
|
||||
break;
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
@ -208,45 +178,14 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
return session != nullptr;
|
||||
}
|
||||
|
||||
bool GstOnnxClient::setTensorDescDatatype(ONNXTensorElementDataType dt,
|
||||
GstStructure *tensor_desc)
|
||||
{
|
||||
GValue val = G_VALUE_INIT;
|
||||
GstTensorDataType gst_dt;
|
||||
|
||||
g_value_init(&val, G_TYPE_STRING);
|
||||
|
||||
if (dt > ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED &&
|
||||
dt <= ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4) {
|
||||
gst_dt = (GstTensorDataType)ONNX_TO_GST_TENSOR_DATATYPE [dt];
|
||||
g_value_set_string (&val, gst_tensor_data_type_get_name(gst_dt));
|
||||
} else {
|
||||
GST_ERROR_OBJECT (debug_parent, "Unexpected datatype: %d", dt);
|
||||
g_value_unset (&val);
|
||||
return false;
|
||||
}
|
||||
|
||||
gst_structure_take_value(tensor_desc, "type", &val);
|
||||
g_value_unset(&val);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GstOnnxClient::createSession (std::string modelFile,
|
||||
GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider,
|
||||
GstStructure * tensors)
|
||||
GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
|
||||
{
|
||||
OrtStatus *status;
|
||||
if (session)
|
||||
return true;
|
||||
|
||||
try {
|
||||
Ort::SessionOptions sessionOptions;
|
||||
const auto & api = Ort::GetApi ();
|
||||
// for debugging
|
||||
//sessionOptions.SetIntraOpNumThreads (1);
|
||||
|
||||
GraphOptimizationLevel onnx_optim;
|
||||
switch (optim) {
|
||||
GraphOptimizationLevel onnx_optim;
|
||||
switch (optim) {
|
||||
case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL:
|
||||
onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL;
|
||||
break;
|
||||
@ -262,66 +201,61 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
default:
|
||||
onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
|
||||
break;
|
||||
};
|
||||
};
|
||||
|
||||
try {
|
||||
Ort::SessionOptions sessionOptions;
|
||||
const auto & api = Ort::GetApi ();
|
||||
// for debugging
|
||||
//sessionOptions.SetIntraOpNumThreads (1);
|
||||
sessionOptions.SetGraphOptimizationLevel (onnx_optim);
|
||||
|
||||
m_provider = provider;
|
||||
switch (m_provider) {
|
||||
case GST_ONNX_EXECUTION_PROVIDER_CUDA:
|
||||
case GST_ONNX_EXECUTION_PROVIDER_CUDA:
|
||||
try {
|
||||
OrtCUDAProviderOptionsV2 *cuda_options = nullptr;
|
||||
Ort::ThrowOnError (api.CreateCUDAProviderOptions (&cuda_options));
|
||||
std::unique_ptr < OrtCUDAProviderOptionsV2,
|
||||
decltype (api.ReleaseCUDAProviderOptions) >
|
||||
rel_cuda_options (cuda_options, api.ReleaseCUDAProviderOptions);
|
||||
rel_cuda_options (cuda_options, api.ReleaseCUDAProviderOptions);
|
||||
Ort::ThrowOnError (api.SessionOptionsAppendExecutionProvider_CUDA_V2
|
||||
(static_cast < OrtSessionOptions * >(sessionOptions),
|
||||
rel_cuda_options.get ()));
|
||||
} catch (Ort::Exception & ortex) {
|
||||
GST_WARNING
|
||||
("Failed to create CUDA provider - dropping back to CPU");
|
||||
}
|
||||
catch (Ort::Exception & ortex) {
|
||||
GST_WARNING
|
||||
("Failed to create CUDA provider - dropping back to CPU");
|
||||
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
(sessionOptions, 1));
|
||||
}
|
||||
break;
|
||||
case GST_ONNX_EXECUTION_PROVIDER_OPENVINO: {
|
||||
std::unordered_map<std::string, std::string> ovOptions;
|
||||
if (this->provider_config) {
|
||||
std::istringstream ss(this->provider_config);
|
||||
std::string kv;
|
||||
while (std::getline(ss, kv, ',')) {
|
||||
auto pos = kv.find('=');
|
||||
if (pos == std::string::npos) continue;
|
||||
ovOptions[kv.substr(0, pos)] = kv.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
sessionOptions.AppendExecutionProvider("OpenVINO", ovOptions);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
(sessionOptions, 1));
|
||||
}
|
||||
break;
|
||||
#ifdef HAVE_VSI_NPU
|
||||
case GST_ONNX_EXECUTION_PROVIDER_VSI:
|
||||
try {
|
||||
|
||||
status = OrtSessionOptionsAppendExecutionProvider_VSINPU(sessionOptions);
|
||||
if (status != nullptr) {
|
||||
GST_ERROR_OBJECT (debug_parent,
|
||||
"Failed to set VSINPU AI execution provider: %s",
|
||||
Ort::GetApi().GetErrorMessage(status));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
catch (Ort::Exception & ortex) {
|
||||
GST_ERROR_OBJECT (debug_parent,
|
||||
"Failed to set VSINPU AI execution provider: %s", ortex.what ());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
sessionOptions.DisableCpuMemArena();
|
||||
break;
|
||||
#endif
|
||||
|
||||
default:
|
||||
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
(sessionOptions, 1));
|
||||
break;
|
||||
}
|
||||
|
||||
env = Ort::Env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
|
||||
break;
|
||||
};
|
||||
env =
|
||||
Ort::Env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
|
||||
"GstOnnxNamespace");
|
||||
env.DisableTelemetryEvents();
|
||||
session = new Ort::Session (env, modelFile.c_str (), sessionOptions);
|
||||
|
||||
auto inputTypeInfo = session->GetInputTypeInfo (0);
|
||||
std::vector < int64_t > inputDims =
|
||||
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
|
||||
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
|
||||
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
||||
height = inputDims[1];
|
||||
width = inputDims[2];
|
||||
@ -334,23 +268,23 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
|
||||
fixedInputImageSize = width > 0 && height > 0;
|
||||
GST_DEBUG_OBJECT (debug_parent, "Number of Output Nodes: %d",
|
||||
(gint) session->GetOutputCount ());
|
||||
(gint) session->GetOutputCount ());
|
||||
|
||||
ONNXTensorElementDataType elementType =
|
||||
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetElementType ();
|
||||
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetElementType ();
|
||||
|
||||
switch (elementType) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
setInputImageDatatype(GST_TENSOR_DATA_TYPE_UINT8);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
setInputImageDatatype(GST_TENSOR_DATA_TYPE_FLOAT32);
|
||||
break;
|
||||
default:
|
||||
GST_ERROR_OBJECT (debug_parent,
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
setInputImageDatatype(GST_TENSOR_DATA_TYPE_UINT8);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
setInputImageDatatype(GST_TENSOR_DATA_TYPE_FLOAT32);
|
||||
break;
|
||||
default:
|
||||
GST_ERROR_OBJECT (debug_parent,
|
||||
"Only input tensors of type int8 and floatare supported");
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto input_name = session->GetInputNameAllocated (0, allocator);
|
||||
@ -358,8 +292,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
|
||||
for (size_t i = 0; i < session->GetOutputCount (); ++i) {
|
||||
auto output_name = session->GetOutputNameAllocated (i, allocator);
|
||||
GST_DEBUG_OBJECT (debug_parent, "Output name %lu:%s", i,
|
||||
output_name.get ());
|
||||
GST_DEBUG_OBJECT (debug_parent, "Output name %lu:%s", i, output_name.get ());
|
||||
outputNames.push_back (std::move (output_name));
|
||||
}
|
||||
genOutputNamesRaw ();
|
||||
@ -367,8 +300,8 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
// look up tensor ids
|
||||
auto metaData = session->GetModelMetadata ();
|
||||
OrtAllocator *ortAllocator;
|
||||
status =
|
||||
Ort::GetApi ().GetAllocatorWithDefaultOptions (&ortAllocator);
|
||||
auto status =
|
||||
Ort::GetApi ().GetAllocatorWithDefaultOptions (&ortAllocator);
|
||||
if (status) {
|
||||
// Handle the error case
|
||||
const char *errorString = Ort::GetApi ().GetErrorMessage (status);
|
||||
@ -379,26 +312,20 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t i = 0;
|
||||
for (auto & name:outputNamesRaw) {
|
||||
Ort::TypeInfo ti = session->GetOutputTypeInfo(i++);
|
||||
auto type_shape = ti.GetTensorTypeAndShapeInfo();
|
||||
auto card = type_shape.GetDimensionsCount();
|
||||
auto type = type_shape.GetElementType();
|
||||
Ort::AllocatedStringPtr res =
|
||||
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
|
||||
|
||||
if (res) {
|
||||
if (res)
|
||||
{
|
||||
GQuark quark = g_quark_from_string (res.get ());
|
||||
outputIds.push_back (quark);
|
||||
} else if (g_str_has_prefix (name, "scores")) {
|
||||
} else if (g_str_has_prefix (name, "detection_scores")) {
|
||||
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES);
|
||||
GST_INFO_OBJECT(debug_parent,
|
||||
"No custom metadata for key '%s', assuming %s",
|
||||
name, GST_MODEL_OBJECT_DETECTOR_SCORES);
|
||||
outputIds.push_back (quark);
|
||||
} else if (g_str_has_prefix(name, "boxes")) {
|
||||
} else if (g_str_has_prefix(name, "detection_boxes")) {
|
||||
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES);
|
||||
GST_INFO_OBJECT(debug_parent,
|
||||
"No custom metadata for key '%s', assuming %s",
|
||||
@ -420,53 +347,16 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
|
||||
return false;
|
||||
}
|
||||
|
||||
GST_DEBUG_OBJECT (debug_parent, "Tensor %zu (%s) has id \"%s\"", i, name,
|
||||
g_quark_to_string (outputIds.back ()));
|
||||
|
||||
/* tensor description */
|
||||
GstStructure *tensor_desc = gst_structure_new_empty("tensor/strided");
|
||||
|
||||
/* Setting dims */
|
||||
GValue val_dims = G_VALUE_INIT, val = G_VALUE_INIT;
|
||||
gst_value_array_init(&val_dims, card);
|
||||
g_value_init(&val, G_TYPE_INT);
|
||||
|
||||
for (auto &dim : type_shape.GetShape()) {
|
||||
g_value_set_int(&val, dim > 0 ? dim : 0);
|
||||
gst_value_array_append_value(&val_dims, &val);
|
||||
}
|
||||
gst_structure_take_value(tensor_desc, "dims", &val_dims);
|
||||
g_value_unset(&val_dims);
|
||||
g_value_unset(&val);
|
||||
|
||||
/* Setting datatype */
|
||||
if (!setTensorDescDatatype(type, tensor_desc))
|
||||
return false;
|
||||
|
||||
/* Setting tensors caps */
|
||||
gst_structure_set(tensors, res.get(), GST_TYPE_CAPS,
|
||||
gst_caps_new_full(tensor_desc, NULL), NULL);
|
||||
}
|
||||
|
||||
} catch (Ort::Exception & ortex) {
|
||||
}
|
||||
catch (Ort::Exception & ortex) {
|
||||
GST_ERROR_OBJECT (debug_parent, "%s", ortex.what ());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void GstOnnxClient::destroySession (void)
|
||||
{
|
||||
if (session == NULL)
|
||||
return;
|
||||
|
||||
delete session;
|
||||
session = NULL;
|
||||
}
|
||||
|
||||
void GstOnnxClient::parseDimensions (GstVideoInfo vinfo)
|
||||
{
|
||||
int32_t newWidth = fixedInputImageSize ? width : vinfo.width;
|
||||
@ -575,7 +465,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||
|
||||
std::ostringstream buffer;
|
||||
buffer << inputDims;
|
||||
GST_LOG_OBJECT (debug_parent, "Input dimensions: %s", buffer.str ().c_str ());
|
||||
GST_DEBUG_OBJECT (debug_parent, "Input dimensions: %s", buffer.str ().c_str ());
|
||||
|
||||
// copy video frame
|
||||
uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
|
||||
|
@ -54,7 +54,7 @@ typedef enum
|
||||
{
|
||||
GST_ONNX_EXECUTION_PROVIDER_CPU,
|
||||
GST_ONNX_EXECUTION_PROVIDER_CUDA,
|
||||
GST_ONNX_EXECUTION_PROVIDER_VSI,
|
||||
GST_ONNX_EXECUTION_PROVIDER_OPENVINO,
|
||||
} GstOnnxExecutionProvider;
|
||||
|
||||
|
||||
@ -65,10 +65,8 @@ namespace GstOnnxNamespace {
|
||||
GstOnnxClient(GstElement *debug_parent);
|
||||
~GstOnnxClient(void);
|
||||
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
|
||||
GstOnnxExecutionProvider provider, GstStructure *
|
||||
tensors);
|
||||
GstOnnxExecutionProvider provider);
|
||||
bool hasSession(void);
|
||||
void destroySession(void);
|
||||
void setInputImageFormat(GstMlInputImageFormat format);
|
||||
GstMlInputImageFormat getInputImageFormat(void);
|
||||
GstTensorDataType getInputImageDatatype(void);
|
||||
@ -85,6 +83,7 @@ namespace GstOnnxNamespace {
|
||||
GstTensorMeta *copy_tensors_to_meta (std::vector<Ort::Value> &outputs,
|
||||
GstBuffer *buffer);
|
||||
void parseDimensions(GstVideoInfo vinfo);
|
||||
void setProviderConfig(const char *config);
|
||||
private:
|
||||
|
||||
GstElement *debug_parent;
|
||||
@ -93,7 +92,6 @@ namespace GstOnnxNamespace {
|
||||
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
||||
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
||||
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
|
||||
bool setTensorDescDatatype (ONNXTensorElementDataType dt, GstStructure * tensor_desc);
|
||||
Ort::Env env;
|
||||
Ort::Session * session;
|
||||
int32_t width;
|
||||
@ -112,6 +110,7 @@ namespace GstOnnxNamespace {
|
||||
bool fixedInputImageSize;
|
||||
float inputTensorOffset;
|
||||
float inputTensorScale;
|
||||
const char *provider_config;
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -72,6 +72,7 @@
|
||||
* @optimization_level: ONNX session optimization level
|
||||
* @execution_provider: ONNX execution provider
|
||||
* @onnx_client opaque pointer to ONNX client
|
||||
* @onnx_disabled true if inference is disabled
|
||||
* @video_info @ref GstVideoInfo of sink caps
|
||||
*/
|
||||
struct _GstOnnxInference
|
||||
@ -81,8 +82,9 @@ struct _GstOnnxInference
|
||||
GstOnnxOptimizationLevel optimization_level;
|
||||
GstOnnxExecutionProvider execution_provider;
|
||||
gpointer onnx_client;
|
||||
gboolean onnx_disabled;
|
||||
GstVideoInfo video_info;
|
||||
GstStructure *tensors;
|
||||
gchar *provider_config;
|
||||
};
|
||||
|
||||
GST_DEBUG_CATEGORY (onnx_inference_debug);
|
||||
@ -100,6 +102,7 @@ enum
|
||||
PROP_INPUT_IMAGE_FORMAT,
|
||||
PROP_OPTIMIZATION_LEVEL,
|
||||
PROP_EXECUTION_PROVIDER,
|
||||
PROP_PROVIDER_CONFIG,
|
||||
PROP_INPUT_OFFSET,
|
||||
PROP_INPUT_SCALE
|
||||
};
|
||||
@ -130,13 +133,12 @@ static GstFlowReturn gst_onnx_inference_transform_ip (GstBaseTransform *
|
||||
trans, GstBuffer * buf);
|
||||
static gboolean gst_onnx_inference_process (GstBaseTransform * trans,
|
||||
GstBuffer * buf);
|
||||
static gboolean gst_onnx_inference_create_session (GstBaseTransform * trans);
|
||||
static GstCaps *gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
|
||||
static gboolean
|
||||
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||
GstCaps * outcaps);
|
||||
static gboolean gst_onnx_inference_start (GstBaseTransform * trans);
|
||||
static gboolean gst_onnx_inference_stop (GstBaseTransform * trans);
|
||||
|
||||
G_DEFINE_TYPE (GstOnnxInference, gst_onnx_inference, GST_TYPE_BASE_TRANSFORM);
|
||||
|
||||
@ -187,24 +189,12 @@ gst_onnx_execution_provider_get_type (void)
|
||||
static GEnumValue execution_provider_types[] = {
|
||||
{GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider",
|
||||
"cpu"},
|
||||
#if HAVE_CUDA
|
||||
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
|
||||
"CUDA execution provider",
|
||||
"cuda"},
|
||||
#else
|
||||
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
|
||||
"CUDA execution provider (compiled out, will use CPU)",
|
||||
"cuda"},
|
||||
#endif
|
||||
#ifdef HAVE_VSI_NPU
|
||||
{GST_ONNX_EXECUTION_PROVIDER_VSI,
|
||||
"VeriSilicon NPU execution provider",
|
||||
"vsi"},
|
||||
#else
|
||||
{GST_ONNX_EXECUTION_PROVIDER_VSI,
|
||||
"VeriSilicon NPU execution provider (compiled out, will use CPU)",
|
||||
"vsi"},
|
||||
#endif
|
||||
{GST_ONNX_EXECUTION_PROVIDER_OPENVINO,
|
||||
"OPENVINO execution provider",
|
||||
"openvino"},
|
||||
{0, NULL, NULL},
|
||||
};
|
||||
|
||||
@ -331,6 +321,14 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||
G_MINFLOAT, G_MAXFLOAT, 1.0,
|
||||
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_PROVIDER_CONFIG,
|
||||
g_param_spec_string ("provider-config",
|
||||
"Provider config",
|
||||
"Comma-separierte Key=Value-Optionen",
|
||||
nullptr,
|
||||
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
|
||||
gst_element_class_set_static_metadata (element_class, "onnxinference",
|
||||
"Filter/Effect/Video",
|
||||
@ -346,10 +344,6 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_caps);
|
||||
basetransform_class->set_caps =
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_inference_set_caps);
|
||||
basetransform_class->start =
|
||||
GST_DEBUG_FUNCPTR(gst_onnx_inference_start);
|
||||
basetransform_class->stop =
|
||||
GST_DEBUG_FUNCPTR(gst_onnx_inference_stop);
|
||||
|
||||
gst_type_mark_as_plugin_api (GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
|
||||
(GstPluginAPIFlags) 0);
|
||||
@ -363,18 +357,16 @@ static void
|
||||
gst_onnx_inference_init (GstOnnxInference * self)
|
||||
{
|
||||
self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
|
||||
/* TODO: at the moment onnx inference only support video output. We
|
||||
* should revisit this once we generalize this aspect */
|
||||
self->tensors = gst_structure_new_empty ("video/x-raw");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_inference_finalize (GObject * object)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
|
||||
|
||||
g_free (self->provider_config);
|
||||
self->provider_config = NULL;
|
||||
g_free (self->model_file);
|
||||
gst_structure_free(self->tensors);
|
||||
delete GST_ONNX_CLIENT_MEMBER (self);
|
||||
G_OBJECT_CLASS (gst_onnx_inference_parent_class)->finalize (object);
|
||||
}
|
||||
@ -396,6 +388,7 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
|
||||
if (self->model_file)
|
||||
g_free (self->model_file);
|
||||
self->model_file = g_strdup (filename);
|
||||
self->onnx_disabled = FALSE;
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
|
||||
}
|
||||
@ -418,6 +411,11 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
|
||||
case PROP_INPUT_SCALE:
|
||||
onnxClient->setInputImageScale (g_value_get_float (value));
|
||||
break;
|
||||
case PROP_PROVIDER_CONFIG:
|
||||
g_free (self->provider_config);
|
||||
self->provider_config = g_value_dup_string (value);
|
||||
onnxClient->setProviderConfig(self->provider_config);
|
||||
break;
|
||||
default:
|
||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||
break;
|
||||
@ -456,6 +454,45 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
|
||||
}
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_create_session (GstBaseTransform * trans)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
if (self->onnx_disabled) {
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
if (onnxClient->hasSession ()) {
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
if (self->model_file) {
|
||||
gboolean ret =
|
||||
GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
|
||||
self->optimization_level,
|
||||
self->execution_provider);
|
||||
if (!ret) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Unable to create ONNX session. Model is disabled.");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
} else {
|
||||
self->onnx_disabled = TRUE;
|
||||
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Model file not found"));
|
||||
}
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
if (self->onnx_disabled) {
|
||||
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
|
||||
}
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static GstCaps *
|
||||
gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
|
||||
@ -464,17 +501,9 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
GstCaps *other_caps;
|
||||
GstCaps *restrictions;
|
||||
bool has_session;
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
has_session = onnxClient->hasSession ();
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
if (!has_session) {
|
||||
other_caps = gst_caps_ref (caps);
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (!gst_onnx_inference_create_session (trans))
|
||||
return NULL;
|
||||
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
|
||||
|
||||
if (gst_base_transform_is_passthrough (trans))
|
||||
@ -528,20 +557,10 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
GST_DEBUG_OBJECT(self, "Applying caps restrictions: %" GST_PTR_FORMAT,
|
||||
restrictions);
|
||||
|
||||
if (direction == GST_PAD_SINK) {
|
||||
GstCaps * tensors_caps = gst_caps_new_full (gst_structure_copy (
|
||||
self->tensors), NULL);
|
||||
GstCaps *intersect = gst_caps_intersect (restrictions, tensors_caps);
|
||||
gst_caps_replace (&restrictions, intersect);
|
||||
gst_caps_unref (tensors_caps);
|
||||
gst_caps_unref (intersect);
|
||||
}
|
||||
|
||||
other_caps = gst_caps_intersect_full (caps, restrictions,
|
||||
GST_CAPS_INTERSECT_FIRST);
|
||||
gst_caps_unref (restrictions);
|
||||
|
||||
done:
|
||||
if (filter_caps) {
|
||||
GstCaps *tmp = gst_caps_intersect_full (
|
||||
other_caps, filter_caps, GST_CAPS_INTERSECT_FIRST);
|
||||
@ -552,53 +571,6 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
return other_caps;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_start (GstBaseTransform * trans)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
gboolean ret = FALSE;
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
if (onnxClient->hasSession ()) {
|
||||
ret = TRUE;
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (self->model_file == NULL) {
|
||||
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
|
||||
("model-file property not set"));
|
||||
goto done;
|
||||
}
|
||||
|
||||
ret = GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
|
||||
self->optimization_level,
|
||||
self->execution_provider,
|
||||
self->tensors);
|
||||
if (!ret)
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Unable to create ONNX session. Model is disabled.");
|
||||
|
||||
done:
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_stop (GstBaseTransform * trans)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
if (onnxClient->hasSession ())
|
||||
GST_ONNX_CLIENT_MEMBER (self)->destroySession ();
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||
GstCaps * outcaps)
|
||||
|
Loading…
x
Reference in New Issue
Block a user