diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index a8600d2052..f8df0f3374 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -203,56 +203,65 @@ bool GstOnnxClient::createSession (std::string modelFile, break; }; - Ort::SessionOptions sessionOptions; - // for debugging - //sessionOptions.SetIntraOpNumThreads (1); - sessionOptions.SetGraphOptimizationLevel (onnx_optim); - m_provider = provider; - switch (m_provider) { - case GST_ONNX_EXECUTION_PROVIDER_CUDA: + try { + Ort::SessionOptions sessionOptions; + // for debugging + //sessionOptions.SetIntraOpNumThreads (1); + sessionOptions.SetGraphOptimizationLevel (onnx_optim); + m_provider = provider; + switch (m_provider) { + case GST_ONNX_EXECUTION_PROVIDER_CUDA: #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA - Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA - (sessionOptions, 0)); + Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA + (sessionOptions, 0)); #else - return false; + GST_ERROR ("ONNX CUDA execution provider not supported"); + return false; #endif - break; - default: - break; + break; + default: + break; - }; - session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions); - auto inputTypeInfo = session->GetInputTypeInfo (0); - std::vector < int64_t > inputDims = - inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); - if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { - height = inputDims[1]; - width = inputDims[2]; - channels = inputDims[3]; - } else { - channels = inputDims[1]; - height = inputDims[2]; - width = inputDims[3]; - } - - fixedInputImageSize = width > 0 && height > 0; - GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ()); - - Ort::AllocatorWithDefaultOptions allocator; - auto input_name = session->GetInputNameAllocated (0, allocator); - GST_DEBUG ("Input name: %s", input_name.get()); - - for (size_t i = 0; i < session->GetOutputCount (); ++i) { - auto output_name = session->GetOutputNameAllocated (i, allocator); - GST_DEBUG("Output name %lu:%s", i, output_name.get()); - outputNames.push_back (std::move(output_name)); - auto type_info = session->GetOutputTypeInfo (i); - auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); - - if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) { - auto function = outputNodeIndexToFunction[i]; - outputNodeInfo[function].type = tensor_info.GetElementType (); + }; + session = + new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions); + auto inputTypeInfo = session->GetInputTypeInfo (0); + std::vector < int64_t > inputDims = + inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); + if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { + height = inputDims[1]; + width = inputDims[2]; + channels = inputDims[3]; + } else { + channels = inputDims[1]; + height = inputDims[2]; + width = inputDims[3]; } + + fixedInputImageSize = width > 0 && height > 0; + GST_DEBUG ("Number of Output Nodes: %d", + (gint) session->GetOutputCount ()); + + Ort::AllocatorWithDefaultOptions allocator; + auto input_name = session->GetInputNameAllocated (0, allocator); + GST_DEBUG ("Input name: %s", input_name.get ()); + + for (size_t i = 0; i < session->GetOutputCount (); ++i) { + auto output_name = session->GetOutputNameAllocated (i, allocator); + GST_DEBUG ("Output name %lu:%s", i, output_name.get ()); + outputNames.push_back (std::move (output_name)); + auto type_info = session->GetOutputTypeInfo (i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); + + if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) { + auto function = outputNodeIndexToFunction[i]; + outputNodeInfo[function].type = tensor_info.GetElementType (); + } + } + } + catch (Ort::Exception & ortex) { + GST_ERROR ("%s", ortex.what ()); + return false; } return true; diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp index 720d8236b4..c86bd40205 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp @@ -643,9 +643,15 @@ gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf) } if (gst_buffer_map (buf, &info, GST_MAP_READ)) { GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans); - auto boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta, - self->label_file ? self->label_file : "", - self->score_threshold); + std::vector < GstOnnxNamespace::GstMlBoundingBox > boxes; + try { + boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta, + self->label_file ? self->label_file : "", self->score_threshold); + } + catch (Ort::Exception & ortex) { + GST_ERROR_OBJECT (self, "%s", ortex.what ()); + return FALSE; + } for (auto & b:boxes) { auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf, GST_ONNX_OBJECT_DETECTOR_META_NAME,