From 5e1291fd86b636689ee012eef0cd79505495c6bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Cr=C3=AAte?= Date: Wed, 24 Jan 2024 22:31:21 -0500 Subject: [PATCH] onnx: Only read labels file one and use GIO Part-of: --- .../onnx/decoders/gstobjectdetectorutils.cpp | 78 +++++++++++++------ .../onnx/decoders/gstobjectdetectorutils.h | 7 +- .../onnx/decoders/gstssdobjectdetector.cpp | 25 +++--- .../ext/onnx/decoders/gstssdobjectdetector.h | 1 + .../gst-plugins-bad/ext/onnx/meson.build | 2 +- 5 files changed, 78 insertions(+), 35 deletions(-) diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp index 54221f7095..dbd4b30843 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp @@ -22,7 +22,56 @@ #include "gstobjectdetectorutils.h" -#include +#include + + +char ** +read_labels (const char * labels_file) +{ + GPtrArray *array; + GFile *file = g_file_new_for_path (labels_file); + GFileInputStream *file_stream; + GDataInputStream *data_stream; + GError *error = NULL; + gchar *line; + + file_stream = g_file_read (file, NULL, &error); + g_object_unref (file); + if (!file_stream) { + GST_WARNING ("Could not open file %s: %s\n", labels_file, + error->message); + g_clear_error (&error); + return NULL; + } + + data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream)); + g_object_unref (file_stream); + + array = g_ptr_array_new(); + + while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL, + &error))) + g_ptr_array_add (array, line); + + g_object_unref (data_stream); + + if (error) { + GST_WARNING ("Could not open file %s: %s", labels_file, error->message); + g_ptr_array_free (array, TRUE); + g_clear_error (&error); + return NULL; + } + + if (array->len == 0) { + g_ptr_array_free (array, TRUE); + return NULL; + } + + g_ptr_array_add (array, NULL); + + return (char **) g_ptr_array_free (array, FALSE); +} + GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0, float _y0, float _width, float _height): @@ -47,20 +96,9 @@ namespace GstObjectDetectorUtils { } - std::vector < std::string > - GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile) - { - std::vector < std::string > labels; - std::string line; - std::ifstream fp (labelsFile); - while (std::getline (fp, line)) - labels.push_back (line); - - return labels; - } std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w, - int32_t h, GstTensorMeta * tmeta, std::string labelPath, + int32_t h, GstTensorMeta * tmeta, gchar **labels, float scoreThreshold) { @@ -72,18 +110,17 @@ namespace GstObjectDetectorUtils } auto type = tmeta->tensor[classIndex].type; return (type == GST_TENSOR_TYPE_FLOAT32) ? - doRun < float >(w, h, tmeta, labelPath, scoreThreshold) - : doRun < int >(w, h, tmeta, labelPath, scoreThreshold); + doRun < float >(w, h, tmeta, labels, scoreThreshold) + : doRun < int >(w, h, tmeta, labels, scoreThreshold); } template < typename T > std::vector < GstMlBoundingBox > GstObjectDetectorUtils::doRun (int32_t w, int32_t h, - GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold) + GstTensorMeta * tmeta, char **labels, float scoreThreshold) { std::vector < GstMlBoundingBox > boundingBoxes; GstMapInfo map_info[GstObjectDetectorMaxNodes]; GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL }; - std::vector < std::string > labels; gint index; T *numDetections = nullptr, *bboxes = nullptr, *scores = nullptr, *labelIndex = nullptr; @@ -162,15 +199,12 @@ namespace GstObjectDetectorUtils labelIndex = (T *) map_info[index].data; } - if (!labelPath.empty ()) - labels = ReadLabels (labelPath); - for (int i = 0; i < numDetections[0]; ++i) { if (scores[i] > scoreThreshold) { std::string label = ""; - if (labelIndex && !labels.empty ()) - label = labels[labelIndex[i] - 1]; + if (labels && labelIndex) + label = labels[(int)labelIndex[i] - 1]; auto score = scores[i]; auto y0 = bboxes[i * 4] * h; auto x0 = bboxes[i * 4 + 1] * w; diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h index 5668ec6b23..2e7a83995a 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h +++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h @@ -29,6 +29,8 @@ #include "gstml.h" #include "tensor/gsttensormeta.h" +char ** read_labels (const char * labels_file); + /* Object detection tensor id strings */ #define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes" #define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores" @@ -68,14 +70,13 @@ namespace GstObjectDetectorUtils { ~GstObjectDetectorUtils(void) = default; std::vector < GstMlBoundingBox > run(int32_t w, int32_t h, GstTensorMeta *tmeta, - std::string labelPath, + char **labels, float scoreThreshold); private: template < typename T > std::vector < GstMlBoundingBox > doRun(int32_t w, int32_t h, - GstTensorMeta *tmeta, std::string labelPath, + GstTensorMeta *tmeta, char **labels, float scoreThreshold); - std::vector < std::string > ReadLabels(const std::string & labelsFile); }; } diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp index 1988104a88..5c69579323 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp @@ -164,6 +164,7 @@ gst_ssd_object_detector_finalize (GObject * object) GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object); g_free (self->label_file); + g_strfreev (self->labels); delete GST_ODUTILS_MEMBER (self); G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object); @@ -178,14 +179,20 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id, switch (prop_id) { case PROP_LABEL_FILE: - filename = g_value_get_string (value); - if (filename - && g_file_test (filename, - (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) { - g_free (self->label_file); - self->label_file = g_strdup (filename); - } else { - GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename); + { + gchar **labels; + + filename = g_value_get_string (value); + labels = read_labels (filename); + + if (labels) { + g_free (self->label_file); + self->label_file = g_strdup (filename); + g_strfreev (self->labels); + self->labels = labels; + } else { + GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename); + } } break; case PROP_SCORE_THRESHOLD: @@ -313,7 +320,7 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf) std::vector < GstMlBoundingBox > boxes = GST_ODUTILS_MEMBER (self)->run (self->video_info.width, - self->video_info.height, tmeta, self->label_file ? self->label_file : "", + self->video_info.height, tmeta, self->labels, self->score_threshold); for (auto & b:boxes) { diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h index 4549ad4c92..b101650ba4 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h +++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h @@ -52,6 +52,7 @@ struct _GstSsdObjectDetector { GstBaseTransform basetransform; gchar *label_file; + gchar **labels; gfloat score_threshold; gfloat confidence_threshold; gfloat iou_threshold; diff --git a/subprojects/gst-plugins-bad/ext/onnx/meson.build b/subprojects/gst-plugins-bad/ext/onnx/meson.build index f2d40006bf..11f93eaa3f 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/meson.build +++ b/subprojects/gst-plugins-bad/ext/onnx/meson.build @@ -25,7 +25,7 @@ if onnxrt_dep.found() link_args : noseh_link_args, include_directories : [configinc, libsinc, cuda_stubinc], dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep, onnxrt_dep, - libm] + extra_deps, + libm, gio_dep] + extra_deps, install : true, install_dir : plugins_install_dir, )