onnx: Only read labels file one and use GIO

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
Olivier Crête 2024-01-24 22:31:21 -05:00
parent 13de5160be
commit 5e1291fd86
5 changed files with 78 additions and 35 deletions

View File

@ -22,7 +22,56 @@
#include "gstobjectdetectorutils.h" #include "gstobjectdetectorutils.h"
#include <fstream> #include <gio/gio.h>
char **
read_labels (const char * labels_file)
{
GPtrArray *array;
GFile *file = g_file_new_for_path (labels_file);
GFileInputStream *file_stream;
GDataInputStream *data_stream;
GError *error = NULL;
gchar *line;
file_stream = g_file_read (file, NULL, &error);
g_object_unref (file);
if (!file_stream) {
GST_WARNING ("Could not open file %s: %s\n", labels_file,
error->message);
g_clear_error (&error);
return NULL;
}
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
g_object_unref (file_stream);
array = g_ptr_array_new();
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
&error)))
g_ptr_array_add (array, line);
g_object_unref (data_stream);
if (error) {
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
g_ptr_array_free (array, TRUE);
g_clear_error (&error);
return NULL;
}
if (array->len == 0) {
g_ptr_array_free (array, TRUE);
return NULL;
}
g_ptr_array_add (array, NULL);
return (char **) g_ptr_array_free (array, FALSE);
}
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0, GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
float _y0, float _width, float _height): float _y0, float _width, float _height):
@ -47,20 +96,9 @@ namespace GstObjectDetectorUtils
{ {
} }
std::vector < std::string >
GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile)
{
std::vector < std::string > labels;
std::string line;
std::ifstream fp (labelsFile);
while (std::getline (fp, line))
labels.push_back (line);
return labels;
}
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w, std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
int32_t h, GstTensorMeta * tmeta, std::string labelPath, int32_t h, GstTensorMeta * tmeta, gchar **labels,
float scoreThreshold) float scoreThreshold)
{ {
@ -72,18 +110,17 @@ namespace GstObjectDetectorUtils
} }
auto type = tmeta->tensor[classIndex].type; auto type = tmeta->tensor[classIndex].type;
return (type == GST_TENSOR_TYPE_FLOAT32) ? return (type == GST_TENSOR_TYPE_FLOAT32) ?
doRun < float >(w, h, tmeta, labelPath, scoreThreshold) doRun < float >(w, h, tmeta, labels, scoreThreshold)
: doRun < int >(w, h, tmeta, labelPath, scoreThreshold); : doRun < int >(w, h, tmeta, labels, scoreThreshold);
} }
template < typename T > std::vector < GstMlBoundingBox > template < typename T > std::vector < GstMlBoundingBox >
GstObjectDetectorUtils::doRun (int32_t w, int32_t h, GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold) GstTensorMeta * tmeta, char **labels, float scoreThreshold)
{ {
std::vector < GstMlBoundingBox > boundingBoxes; std::vector < GstMlBoundingBox > boundingBoxes;
GstMapInfo map_info[GstObjectDetectorMaxNodes]; GstMapInfo map_info[GstObjectDetectorMaxNodes];
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL }; GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
std::vector < std::string > labels;
gint index; gint index;
T *numDetections = nullptr, *bboxes = nullptr, *scores = T *numDetections = nullptr, *bboxes = nullptr, *scores =
nullptr, *labelIndex = nullptr; nullptr, *labelIndex = nullptr;
@ -162,15 +199,12 @@ namespace GstObjectDetectorUtils
labelIndex = (T *) map_info[index].data; labelIndex = (T *) map_info[index].data;
} }
if (!labelPath.empty ())
labels = ReadLabels (labelPath);
for (int i = 0; i < numDetections[0]; ++i) { for (int i = 0; i < numDetections[0]; ++i) {
if (scores[i] > scoreThreshold) { if (scores[i] > scoreThreshold) {
std::string label = ""; std::string label = "";
if (labelIndex && !labels.empty ()) if (labels && labelIndex)
label = labels[labelIndex[i] - 1]; label = labels[(int)labelIndex[i] - 1];
auto score = scores[i]; auto score = scores[i];
auto y0 = bboxes[i * 4] * h; auto y0 = bboxes[i * 4] * h;
auto x0 = bboxes[i * 4 + 1] * w; auto x0 = bboxes[i * 4 + 1] * w;

View File

@ -29,6 +29,8 @@
#include "gstml.h" #include "gstml.h"
#include "tensor/gsttensormeta.h" #include "tensor/gsttensormeta.h"
char ** read_labels (const char * labels_file);
/* Object detection tensor id strings */ /* Object detection tensor id strings */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes" #define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores" #define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
@ -68,14 +70,13 @@ namespace GstObjectDetectorUtils {
~GstObjectDetectorUtils(void) = default; ~GstObjectDetectorUtils(void) = default;
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h, std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
GstTensorMeta *tmeta, GstTensorMeta *tmeta,
std::string labelPath, char **labels,
float scoreThreshold); float scoreThreshold);
private: private:
template < typename T > std::vector < GstMlBoundingBox > template < typename T > std::vector < GstMlBoundingBox >
doRun(int32_t w, int32_t h, doRun(int32_t w, int32_t h,
GstTensorMeta *tmeta, std::string labelPath, GstTensorMeta *tmeta, char **labels,
float scoreThreshold); float scoreThreshold);
std::vector < std::string > ReadLabels(const std::string & labelsFile);
}; };
} }

View File

@ -164,6 +164,7 @@ gst_ssd_object_detector_finalize (GObject * object)
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object); GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
g_free (self->label_file); g_free (self->label_file);
g_strfreev (self->labels);
delete GST_ODUTILS_MEMBER (self); delete GST_ODUTILS_MEMBER (self);
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object); G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
@ -178,15 +179,21 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
switch (prop_id) { switch (prop_id) {
case PROP_LABEL_FILE: case PROP_LABEL_FILE:
{
gchar **labels;
filename = g_value_get_string (value); filename = g_value_get_string (value);
if (filename labels = read_labels (filename);
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) { if (labels) {
g_free (self->label_file); g_free (self->label_file);
self->label_file = g_strdup (filename); self->label_file = g_strdup (filename);
g_strfreev (self->labels);
self->labels = labels;
} else { } else {
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename); GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
} }
}
break; break;
case PROP_SCORE_THRESHOLD: case PROP_SCORE_THRESHOLD:
GST_OBJECT_LOCK (self); GST_OBJECT_LOCK (self);
@ -313,7 +320,7 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
std::vector < GstMlBoundingBox > boxes = std::vector < GstMlBoundingBox > boxes =
GST_ODUTILS_MEMBER (self)->run (self->video_info.width, GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
self->video_info.height, tmeta, self->label_file ? self->label_file : "", self->video_info.height, tmeta, self->labels,
self->score_threshold); self->score_threshold);
for (auto & b:boxes) { for (auto & b:boxes) {

View File

@ -52,6 +52,7 @@ struct _GstSsdObjectDetector
{ {
GstBaseTransform basetransform; GstBaseTransform basetransform;
gchar *label_file; gchar *label_file;
gchar **labels;
gfloat score_threshold; gfloat score_threshold;
gfloat confidence_threshold; gfloat confidence_threshold;
gfloat iou_threshold; gfloat iou_threshold;

View File

@ -25,7 +25,7 @@ if onnxrt_dep.found()
link_args : noseh_link_args, link_args : noseh_link_args,
include_directories : [configinc, libsinc, cuda_stubinc], include_directories : [configinc, libsinc, cuda_stubinc],
dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep, onnxrt_dep, dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep, onnxrt_dep,
libm] + extra_deps, libm, gio_dep] + extra_deps,
install : true, install : true,
install_dir : plugins_install_dir, install_dir : plugins_install_dir,
) )