onnxinference: Clean up session creation logic
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/9176>
This commit is contained in:
parent
82a71a7739
commit
f48ad0fde6
@ -72,7 +72,6 @@
|
||||
* @optimization_level: ONNX session optimization level
|
||||
* @execution_provider: ONNX execution provider
|
||||
* @onnx_client opaque pointer to ONNX client
|
||||
* @onnx_disabled true if inference is disabled
|
||||
* @video_info @ref GstVideoInfo of sink caps
|
||||
*/
|
||||
struct _GstOnnxInference
|
||||
@ -82,7 +81,6 @@ struct _GstOnnxInference
|
||||
GstOnnxOptimizationLevel optimization_level;
|
||||
GstOnnxExecutionProvider execution_provider;
|
||||
gpointer onnx_client;
|
||||
gboolean onnx_disabled;
|
||||
GstVideoInfo video_info;
|
||||
GstStructure *tensors;
|
||||
};
|
||||
@ -132,12 +130,12 @@ static GstFlowReturn gst_onnx_inference_transform_ip (GstBaseTransform *
|
||||
trans, GstBuffer * buf);
|
||||
static gboolean gst_onnx_inference_process (GstBaseTransform * trans,
|
||||
GstBuffer * buf);
|
||||
static gboolean gst_onnx_inference_create_session (GstBaseTransform * trans);
|
||||
static GstCaps *gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
|
||||
static gboolean
|
||||
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||
GstCaps * outcaps);
|
||||
static gboolean gst_onnx_inference_start (GstBaseTransform * trans);
|
||||
|
||||
G_DEFINE_TYPE (GstOnnxInference, gst_onnx_inference, GST_TYPE_BASE_TRANSFORM);
|
||||
|
||||
@ -332,6 +330,8 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_caps);
|
||||
basetransform_class->set_caps =
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_inference_set_caps);
|
||||
basetransform_class->start =
|
||||
GST_DEBUG_FUNCPTR(gst_onnx_inference_start);
|
||||
|
||||
gst_type_mark_as_plugin_api (GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
|
||||
(GstPluginAPIFlags) 0);
|
||||
@ -345,8 +345,6 @@ static void
|
||||
gst_onnx_inference_init (GstOnnxInference * self)
|
||||
{
|
||||
self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
|
||||
self->onnx_disabled = TRUE;
|
||||
|
||||
/* TODO: at the moment onnx inference only support video output. We
|
||||
* should revisit this once we generalize this aspect */
|
||||
self->tensors = gst_structure_new_empty ("video/x-raw");
|
||||
@ -380,7 +378,6 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
|
||||
if (self->model_file)
|
||||
g_free (self->model_file);
|
||||
self->model_file = g_strdup (filename);
|
||||
self->onnx_disabled = FALSE;
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
|
||||
}
|
||||
@ -441,45 +438,6 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
|
||||
}
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_create_session (GstBaseTransform * trans)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
if (self->onnx_disabled) {
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
if (onnxClient->hasSession ()) {
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
if (self->model_file) {
|
||||
gboolean ret =
|
||||
GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
|
||||
self->optimization_level,
|
||||
self->execution_provider, self->tensors);
|
||||
if (!ret) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Unable to create ONNX session. Model is disabled.");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
} else {
|
||||
self->onnx_disabled = TRUE;
|
||||
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Model file not found"));
|
||||
}
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
if (self->onnx_disabled) {
|
||||
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
|
||||
}
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static GstCaps *
|
||||
gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
|
||||
@ -488,9 +446,17 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
GstCaps *other_caps;
|
||||
GstCaps *restrictions;
|
||||
bool has_session;
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
has_session = onnxClient->hasSession ();
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
if (!has_session) {
|
||||
other_caps = gst_caps_ref (caps);
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (!gst_onnx_inference_create_session (trans))
|
||||
return NULL;
|
||||
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
|
||||
|
||||
if (gst_base_transform_is_passthrough (trans))
|
||||
@ -557,6 +523,7 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
GST_CAPS_INTERSECT_FIRST);
|
||||
gst_caps_unref (restrictions);
|
||||
|
||||
done:
|
||||
if (filter_caps) {
|
||||
GstCaps *tmp = gst_caps_intersect_full (
|
||||
other_caps, filter_caps, GST_CAPS_INTERSECT_FIRST);
|
||||
@ -567,6 +534,40 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
return other_caps;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_start (GstBaseTransform * trans)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
gboolean ret = FALSE;
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
if (onnxClient->hasSession ()) {
|
||||
ret = TRUE;
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (self->model_file == NULL) {
|
||||
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
|
||||
("model-file property not set"));
|
||||
goto done;
|
||||
}
|
||||
|
||||
ret = GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
|
||||
self->optimization_level,
|
||||
self->execution_provider,
|
||||
self->tensors);
|
||||
if (!ret)
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Unable to create ONNX session. Model is disabled.");
|
||||
|
||||
done:
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||
GstCaps * outcaps)
|
||||
|
Loading…
x
Reference in New Issue
Block a user