onnx: Only read labels file one and use GIO
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
parent
13de5160be
commit
5e1291fd86
@ -22,7 +22,56 @@
|
|||||||
|
|
||||||
#include "gstobjectdetectorutils.h"
|
#include "gstobjectdetectorutils.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <gio/gio.h>
|
||||||
|
|
||||||
|
|
||||||
|
char **
|
||||||
|
read_labels (const char * labels_file)
|
||||||
|
{
|
||||||
|
GPtrArray *array;
|
||||||
|
GFile *file = g_file_new_for_path (labels_file);
|
||||||
|
GFileInputStream *file_stream;
|
||||||
|
GDataInputStream *data_stream;
|
||||||
|
GError *error = NULL;
|
||||||
|
gchar *line;
|
||||||
|
|
||||||
|
file_stream = g_file_read (file, NULL, &error);
|
||||||
|
g_object_unref (file);
|
||||||
|
if (!file_stream) {
|
||||||
|
GST_WARNING ("Could not open file %s: %s\n", labels_file,
|
||||||
|
error->message);
|
||||||
|
g_clear_error (&error);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
|
||||||
|
g_object_unref (file_stream);
|
||||||
|
|
||||||
|
array = g_ptr_array_new();
|
||||||
|
|
||||||
|
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
|
||||||
|
&error)))
|
||||||
|
g_ptr_array_add (array, line);
|
||||||
|
|
||||||
|
g_object_unref (data_stream);
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
|
||||||
|
g_ptr_array_free (array, TRUE);
|
||||||
|
g_clear_error (&error);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (array->len == 0) {
|
||||||
|
g_ptr_array_free (array, TRUE);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
g_ptr_array_add (array, NULL);
|
||||||
|
|
||||||
|
return (char **) g_ptr_array_free (array, FALSE);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
|
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
|
||||||
float _y0, float _width, float _height):
|
float _y0, float _width, float _height):
|
||||||
@ -47,20 +96,9 @@ namespace GstObjectDetectorUtils
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector < std::string >
|
|
||||||
GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile)
|
|
||||||
{
|
|
||||||
std::vector < std::string > labels;
|
|
||||||
std::string line;
|
|
||||||
std::ifstream fp (labelsFile);
|
|
||||||
while (std::getline (fp, line))
|
|
||||||
labels.push_back (line);
|
|
||||||
|
|
||||||
return labels;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
|
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
|
||||||
int32_t h, GstTensorMeta * tmeta, std::string labelPath,
|
int32_t h, GstTensorMeta * tmeta, gchar **labels,
|
||||||
float scoreThreshold)
|
float scoreThreshold)
|
||||||
{
|
{
|
||||||
|
|
||||||
@ -72,18 +110,17 @@ namespace GstObjectDetectorUtils
|
|||||||
}
|
}
|
||||||
auto type = tmeta->tensor[classIndex].type;
|
auto type = tmeta->tensor[classIndex].type;
|
||||||
return (type == GST_TENSOR_TYPE_FLOAT32) ?
|
return (type == GST_TENSOR_TYPE_FLOAT32) ?
|
||||||
doRun < float >(w, h, tmeta, labelPath, scoreThreshold)
|
doRun < float >(w, h, tmeta, labels, scoreThreshold)
|
||||||
: doRun < int >(w, h, tmeta, labelPath, scoreThreshold);
|
: doRun < int >(w, h, tmeta, labels, scoreThreshold);
|
||||||
}
|
}
|
||||||
|
|
||||||
template < typename T > std::vector < GstMlBoundingBox >
|
template < typename T > std::vector < GstMlBoundingBox >
|
||||||
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
|
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
|
||||||
GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold)
|
GstTensorMeta * tmeta, char **labels, float scoreThreshold)
|
||||||
{
|
{
|
||||||
std::vector < GstMlBoundingBox > boundingBoxes;
|
std::vector < GstMlBoundingBox > boundingBoxes;
|
||||||
GstMapInfo map_info[GstObjectDetectorMaxNodes];
|
GstMapInfo map_info[GstObjectDetectorMaxNodes];
|
||||||
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
|
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
|
||||||
std::vector < std::string > labels;
|
|
||||||
gint index;
|
gint index;
|
||||||
T *numDetections = nullptr, *bboxes = nullptr, *scores =
|
T *numDetections = nullptr, *bboxes = nullptr, *scores =
|
||||||
nullptr, *labelIndex = nullptr;
|
nullptr, *labelIndex = nullptr;
|
||||||
@ -162,15 +199,12 @@ namespace GstObjectDetectorUtils
|
|||||||
labelIndex = (T *) map_info[index].data;
|
labelIndex = (T *) map_info[index].data;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!labelPath.empty ())
|
|
||||||
labels = ReadLabels (labelPath);
|
|
||||||
|
|
||||||
for (int i = 0; i < numDetections[0]; ++i) {
|
for (int i = 0; i < numDetections[0]; ++i) {
|
||||||
if (scores[i] > scoreThreshold) {
|
if (scores[i] > scoreThreshold) {
|
||||||
std::string label = "";
|
std::string label = "";
|
||||||
|
|
||||||
if (labelIndex && !labels.empty ())
|
if (labels && labelIndex)
|
||||||
label = labels[labelIndex[i] - 1];
|
label = labels[(int)labelIndex[i] - 1];
|
||||||
auto score = scores[i];
|
auto score = scores[i];
|
||||||
auto y0 = bboxes[i * 4] * h;
|
auto y0 = bboxes[i * 4] * h;
|
||||||
auto x0 = bboxes[i * 4 + 1] * w;
|
auto x0 = bboxes[i * 4 + 1] * w;
|
||||||
|
@ -29,6 +29,8 @@
|
|||||||
#include "gstml.h"
|
#include "gstml.h"
|
||||||
#include "tensor/gsttensormeta.h"
|
#include "tensor/gsttensormeta.h"
|
||||||
|
|
||||||
|
char ** read_labels (const char * labels_file);
|
||||||
|
|
||||||
/* Object detection tensor id strings */
|
/* Object detection tensor id strings */
|
||||||
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
||||||
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
||||||
@ -68,14 +70,13 @@ namespace GstObjectDetectorUtils {
|
|||||||
~GstObjectDetectorUtils(void) = default;
|
~GstObjectDetectorUtils(void) = default;
|
||||||
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
|
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
|
||||||
GstTensorMeta *tmeta,
|
GstTensorMeta *tmeta,
|
||||||
std::string labelPath,
|
char **labels,
|
||||||
float scoreThreshold);
|
float scoreThreshold);
|
||||||
private:
|
private:
|
||||||
template < typename T > std::vector < GstMlBoundingBox >
|
template < typename T > std::vector < GstMlBoundingBox >
|
||||||
doRun(int32_t w, int32_t h,
|
doRun(int32_t w, int32_t h,
|
||||||
GstTensorMeta *tmeta, std::string labelPath,
|
GstTensorMeta *tmeta, char **labels,
|
||||||
float scoreThreshold);
|
float scoreThreshold);
|
||||||
std::vector < std::string > ReadLabels(const std::string & labelsFile);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,6 +164,7 @@ gst_ssd_object_detector_finalize (GObject * object)
|
|||||||
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
|
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
|
||||||
|
|
||||||
g_free (self->label_file);
|
g_free (self->label_file);
|
||||||
|
g_strfreev (self->labels);
|
||||||
delete GST_ODUTILS_MEMBER (self);
|
delete GST_ODUTILS_MEMBER (self);
|
||||||
|
|
||||||
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
|
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
|
||||||
@ -178,15 +179,21 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
|||||||
|
|
||||||
switch (prop_id) {
|
switch (prop_id) {
|
||||||
case PROP_LABEL_FILE:
|
case PROP_LABEL_FILE:
|
||||||
|
{
|
||||||
|
gchar **labels;
|
||||||
|
|
||||||
filename = g_value_get_string (value);
|
filename = g_value_get_string (value);
|
||||||
if (filename
|
labels = read_labels (filename);
|
||||||
&& g_file_test (filename,
|
|
||||||
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
|
if (labels) {
|
||||||
g_free (self->label_file);
|
g_free (self->label_file);
|
||||||
self->label_file = g_strdup (filename);
|
self->label_file = g_strdup (filename);
|
||||||
|
g_strfreev (self->labels);
|
||||||
|
self->labels = labels;
|
||||||
} else {
|
} else {
|
||||||
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case PROP_SCORE_THRESHOLD:
|
case PROP_SCORE_THRESHOLD:
|
||||||
GST_OBJECT_LOCK (self);
|
GST_OBJECT_LOCK (self);
|
||||||
@ -313,7 +320,7 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
|||||||
|
|
||||||
std::vector < GstMlBoundingBox > boxes =
|
std::vector < GstMlBoundingBox > boxes =
|
||||||
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
|
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
|
||||||
self->video_info.height, tmeta, self->label_file ? self->label_file : "",
|
self->video_info.height, tmeta, self->labels,
|
||||||
self->score_threshold);
|
self->score_threshold);
|
||||||
|
|
||||||
for (auto & b:boxes) {
|
for (auto & b:boxes) {
|
||||||
|
@ -52,6 +52,7 @@ struct _GstSsdObjectDetector
|
|||||||
{
|
{
|
||||||
GstBaseTransform basetransform;
|
GstBaseTransform basetransform;
|
||||||
gchar *label_file;
|
gchar *label_file;
|
||||||
|
gchar **labels;
|
||||||
gfloat score_threshold;
|
gfloat score_threshold;
|
||||||
gfloat confidence_threshold;
|
gfloat confidence_threshold;
|
||||||
gfloat iou_threshold;
|
gfloat iou_threshold;
|
||||||
|
@ -25,7 +25,7 @@ if onnxrt_dep.found()
|
|||||||
link_args : noseh_link_args,
|
link_args : noseh_link_args,
|
||||||
include_directories : [configinc, libsinc, cuda_stubinc],
|
include_directories : [configinc, libsinc, cuda_stubinc],
|
||||||
dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep, onnxrt_dep,
|
dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep, onnxrt_dep,
|
||||||
libm] + extra_deps,
|
libm, gio_dep] + extra_deps,
|
||||||
install : true,
|
install : true,
|
||||||
install_dir : plugins_install_dir,
|
install_dir : plugins_install_dir,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user