onnxinference: Clean up session creation logic

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/9176>
This commit is contained in:
Olivier Crête 2025-03-05 17:47:41 -05:00 committed by GStreamer Marge Bot
parent 82a71a7739
commit f48ad0fde6

View File

@ -72,7 +72,6 @@
* @optimization_level: ONNX session optimization level
* @execution_provider: ONNX execution provider
* @onnx_client opaque pointer to ONNX client
* @onnx_disabled true if inference is disabled
* @video_info @ref GstVideoInfo of sink caps
*/
struct _GstOnnxInference
@ -82,7 +81,6 @@ struct _GstOnnxInference
GstOnnxOptimizationLevel optimization_level;
GstOnnxExecutionProvider execution_provider;
gpointer onnx_client;
gboolean onnx_disabled;
GstVideoInfo video_info;
GstStructure *tensors;
};
@ -132,12 +130,12 @@ static GstFlowReturn gst_onnx_inference_transform_ip (GstBaseTransform *
trans, GstBuffer * buf);
static gboolean gst_onnx_inference_process (GstBaseTransform * trans,
GstBuffer * buf);
static gboolean gst_onnx_inference_create_session (GstBaseTransform * trans);
static GstCaps *gst_onnx_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
static gboolean
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
static gboolean gst_onnx_inference_start (GstBaseTransform * trans);
G_DEFINE_TYPE (GstOnnxInference, gst_onnx_inference, GST_TYPE_BASE_TRANSFORM);
@ -332,6 +330,8 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_caps);
basetransform_class->set_caps =
GST_DEBUG_FUNCPTR (gst_onnx_inference_set_caps);
basetransform_class->start =
GST_DEBUG_FUNCPTR(gst_onnx_inference_start);
gst_type_mark_as_plugin_api (GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
(GstPluginAPIFlags) 0);
@ -345,8 +345,6 @@ static void
gst_onnx_inference_init (GstOnnxInference * self)
{
self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
self->onnx_disabled = TRUE;
/* TODO: at the moment onnx inference only support video output. We
* should revisit this once we generalize this aspect */
self->tensors = gst_structure_new_empty ("video/x-raw");
@ -380,7 +378,6 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
if (self->model_file)
g_free (self->model_file);
self->model_file = g_strdup (filename);
self->onnx_disabled = FALSE;
} else {
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
}
@ -441,45 +438,6 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
}
}
static gboolean
gst_onnx_inference_create_session (GstBaseTransform * trans)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
GST_OBJECT_LOCK (self);
if (self->onnx_disabled) {
GST_OBJECT_UNLOCK (self);
return FALSE;
}
if (onnxClient->hasSession ()) {
GST_OBJECT_UNLOCK (self);
return TRUE;
}
if (self->model_file) {
gboolean ret =
GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
self->optimization_level,
self->execution_provider, self->tensors);
if (!ret) {
GST_ERROR_OBJECT (self,
"Unable to create ONNX session. Model is disabled.");
self->onnx_disabled = TRUE;
}
} else {
self->onnx_disabled = TRUE;
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Model file not found"));
}
GST_OBJECT_UNLOCK (self);
if (self->onnx_disabled) {
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
}
return TRUE;
}
static GstCaps *
gst_onnx_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
@ -488,9 +446,17 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
GstCaps *other_caps;
GstCaps *restrictions;
bool has_session;
GST_OBJECT_LOCK (self);
has_session = onnxClient->hasSession ();
GST_OBJECT_UNLOCK (self);
if (!has_session) {
other_caps = gst_caps_ref (caps);
goto done;
}
if (!gst_onnx_inference_create_session (trans))
return NULL;
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
if (gst_base_transform_is_passthrough (trans))
@ -557,6 +523,7 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
GST_CAPS_INTERSECT_FIRST);
gst_caps_unref (restrictions);
done:
if (filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (
other_caps, filter_caps, GST_CAPS_INTERSECT_FIRST);
@ -567,6 +534,40 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
return other_caps;
}
static gboolean
gst_onnx_inference_start (GstBaseTransform * trans)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
gboolean ret = FALSE;
GST_OBJECT_LOCK (self);
if (onnxClient->hasSession ()) {
ret = TRUE;
goto done;
}
if (self->model_file == NULL) {
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
("model-file property not set"));
goto done;
}
ret = GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
self->optimization_level,
self->execution_provider,
self->tensors);
if (!ret)
GST_ERROR_OBJECT (self,
"Unable to create ONNX session. Model is disabled.");
done:
GST_OBJECT_UNLOCK (self);
return ret;
}
static gboolean
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps)