509 lines
17 KiB
C
509 lines
17 KiB
C
/*
|
|
* GStreamer gstreamer-classifiertensordecoder
|
|
* Copyright (C) 2025 Collabora Ltd.
|
|
* @author: Daniel Morin <daniel.morin@dmohub.org>
|
|
*
|
|
* gstclassifiertensordecoder.c
|
|
*
|
|
* This library is free software; you can redistribute it and/or
|
|
* modify it under the terms of the GNU Library General Public
|
|
* License as published by the Free Software Foundation; either
|
|
* version 2 of the License, or (at your option) any later version.
|
|
*
|
|
* This library is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
* Library General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU Library General Public
|
|
* License along with this library; if not, write to the
|
|
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
|
* Boston, MA 02110-1301, USA.
|
|
*/
|
|
|
|
/**
|
|
* SECTION:element-classifiertensordecoder.c
|
|
* @short_description: Decode tensors from classification model using a common
|
|
* tensor output format.
|
|
*
|
|
*
|
|
* This element can parse per-buffer inference tensor meta data generated by
|
|
* an upstream inference element.
|
|
*
|
|
* Tensor format must be:
|
|
* Dims: [batch-size, class_count]
|
|
* Datatype: float32
|
|
*
|
|
* Tensor [M,N]
|
|
* Batch 0 | Class 0 confidence level | ... | Class N confidence level |
|
|
* ...
|
|
* Batch M | Class 0 confidence level | ... | Class N confidence level |
|
|
*
|
|
* In-memory tensor format:
|
|
*
|
|
* |Batch 0, Class 0 confidence level |
|
|
* |Batch 0, ... |
|
|
* |Batch 0, Class N confidence level |
|
|
* | ... |
|
|
* |Batch M, Class 0 confidence level |
|
|
* |Batch M, ... |
|
|
* |Batch M, Class N confidence level |
|
|
*
|
|
*
|
|
* ## Example launch command:
|
|
* |[
|
|
* gst-launch-1.0 filesrc location=/onnx-models/images/bus.jpg \
|
|
* ! jpegdec \
|
|
* ! videoconvertscale add-borders=1 \
|
|
* ! onnxinference execution-provider=cpu \
|
|
* model-file=/onnx-models/models/mobilenet_v1.onnx \
|
|
* ! classifiertensordecoder labels-file=labels.txt ! fakesink \
|
|
* ]| This pipeline creates a tensor decoder for a classification model.
|
|
*
|
|
*/
|
|
|
|
/* Pull in the build configuration when the build system provides one.
 * The guard macro was misspelled (HAVE_CONFI_H), which made this include
 * unreachable. */
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
|
|
|
|
#include "gstclassifiertensordecoder.h"
|
|
#include <gst/gst.h>
|
|
#include <math.h>
|
|
#include <gst/analytics/analytics.h>
|
|
|
|
const gchar GST_MODEL_STD_IMAGE_CLASSIFICATION[] = "classification-generic-out";
|
|
|
|
GST_DEBUG_CATEGORY_STATIC (classifier_tensor_decoder_debug);
|
|
#define GST_CAT_DEFAULT classifier_tensor_decoder_debug
|
|
#define gst_classifier_tensor_decoder_parent_class parent_class
|
|
|
|
GST_ELEMENT_REGISTER_DEFINE (classifier_tensor_decoder,
|
|
"classifiertensordecoder", GST_RANK_PRIMARY,
|
|
GST_TYPE_CLASSIFIER_TENSOR_DECODER);
|
|
|
|
|
|
/* Property identifiers of GstClassifierTensorDecoder */
enum
{
  PROP_0 = 0,
  PROP_THRESHOLD,               /* "class-confidence-threshold" */
  PROP_LABEL_FILE,              /* "labels-file" */
};

/* Initial value of the "class-confidence-threshold" property */
static const float DEFAULT_THRESHOLD = 0.7f;
|
|
|
|
static GstStaticPadTemplate gst_classifier_tensor_decoder_src_template =
|
|
GST_STATIC_PAD_TEMPLATE ("src",
|
|
GST_PAD_SRC,
|
|
GST_PAD_ALWAYS,
|
|
GST_STATIC_CAPS_ANY);
|
|
|
|
static GstStaticPadTemplate gst_classifier_tensor_decoder_sink_template =
|
|
GST_STATIC_PAD_TEMPLATE ("sink",
|
|
GST_PAD_SINK,
|
|
GST_PAD_ALWAYS,
|
|
GST_STATIC_CAPS_ANY);
|
|
|
|
static void gst_classifier_tensor_decoder_set_property (GObject * object,
|
|
guint prop_id, const GValue * value, GParamSpec * pspec);
|
|
static void gst_classifier_tensor_decoder_get_property (GObject * object,
|
|
guint prop_id, GValue * value, GParamSpec * pspec);
|
|
|
|
static void gst_classifier_tensor_decoder_finalize (GObject * object);
|
|
|
|
static GstFlowReturn
|
|
gst_classifier_tensor_decoder_transform_ip (GstBaseTransform * trans,
|
|
GstBuffer * buf);
|
|
|
|
static GstStateChangeReturn
|
|
gst_classifier_tensor_decoder_change_state (GstElement * element,
|
|
GstStateChange transition);
|
|
|
|
#define softmax(len, values, results, max_val) \
|
|
gsize i; \
|
|
gfloat sum = 0.0; \
|
|
gfloat value; \
|
|
g_return_if_fail (values != NULL); \
|
|
g_return_if_fail (result != NULL); \
|
|
\
|
|
/* Calculate exponential of every value */ \
|
|
for (i = 0; i < len; i++) { \
|
|
value = values[i] / max_val; \
|
|
result[i] = exp (value); \
|
|
sum += result[i]; \
|
|
} \
|
|
\
|
|
/* Complete softmax */ \
|
|
for (i = 0; i < len; i++) { \
|
|
result[i] = result[i] / sum; \
|
|
}
|
|
|
|
static void
|
|
softmax_u8 (gsize len, const guint8 * values, gfloat * result)
|
|
{
|
|
softmax (len, values, results, 255.0);
|
|
}
|
|
|
|
static void
|
|
softmax_f32 (gsize len, const gfloat * values, gfloat * result)
|
|
{
|
|
softmax (len, values, results, 1.0);
|
|
}
|
|
|
|
G_DEFINE_TYPE (GstClassifierTensorDecoder, gst_classifier_tensor_decoder,
|
|
GST_TYPE_BASE_TRANSFORM);
|
|
|
|
static void
|
|
gst_classifier_tensor_decoder_class_init (GstClassifierTensorDecoderClass *
|
|
klass)
|
|
{
|
|
GObjectClass *gobject_class = (GObjectClass *) klass;
|
|
GstElementClass *element_class = (GstElementClass *) klass;
|
|
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
|
|
|
|
GST_DEBUG_CATEGORY_INIT (classifier_tensor_decoder_debug,
|
|
"classifiertensordecoder", 0,
|
|
"Tensor decoder for classification model with common output format");
|
|
|
|
gobject_class->set_property = gst_classifier_tensor_decoder_set_property;
|
|
gobject_class->get_property = gst_classifier_tensor_decoder_get_property;
|
|
gobject_class->finalize = gst_classifier_tensor_decoder_finalize;
|
|
|
|
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
|
PROP_THRESHOLD,
|
|
g_param_spec_float ("class-confidence-threshold",
|
|
"Class confidence threshold",
|
|
"Classes with a confidence level inferior to this threshold "
|
|
"will be excluded",
|
|
0.0, 1.0, DEFAULT_THRESHOLD,
|
|
(GParamFlags) (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
|
|
|
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
|
PROP_LABEL_FILE,
|
|
g_param_spec_string ("labels-file",
|
|
"Class labels file",
|
|
"Path to a file containing class label. COCO format",
|
|
NULL, (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
|
|
|
element_class->change_state = gst_classifier_tensor_decoder_change_state;
|
|
|
|
gst_element_class_set_static_metadata (element_class,
|
|
"classifiertensordecoder", "Tensordecoder",
|
|
"Decode tensors output from classification model using common format.\n"
|
|
"\tTensor format must be: \n" "\t\tDims: [batch-size, class_count]\n"
|
|
"\t\tDatatype: float32 \n" "\n" "\t\tTensor [M,N]\n"
|
|
"\t\t\tBatch 0 | Class 0 confidence level | ... | Class N-1 confidence level |\n"
|
|
"\t\t\t...\n"
|
|
"\t\t\tBatch M-1 | Class 0 confidence level | ... | Class N-1 confidence level |\n"
|
|
"\t\t\n" "\tIn-memory tensor format:\n" "\n"
|
|
"\t\t|Batch 0, Class 0 confidence level |\n"
|
|
"\t\t|Batch 0, ... |\n"
|
|
"\t\t|Batch 0, Class N-1 confidence level |\n"
|
|
"\t\t| ... |\n"
|
|
"\t\t|Batch M-1, Class 0 confidence level |\n"
|
|
"\t\t|Batch M-1, ... |\n"
|
|
"\t\t|Batch M-1, Class N-1 confidence level |\n" "\n" " model",
|
|
"Daniel Morin <daniel.morin@collabora.com>");
|
|
|
|
gst_element_class_add_pad_template (element_class,
|
|
gst_static_pad_template_get
|
|
(&gst_classifier_tensor_decoder_src_template));
|
|
gst_element_class_add_pad_template (element_class,
|
|
gst_static_pad_template_get
|
|
(&gst_classifier_tensor_decoder_sink_template));
|
|
|
|
basetransform_class->transform_ip =
|
|
GST_DEBUG_FUNCPTR (gst_classifier_tensor_decoder_transform_ip);
|
|
}
|
|
|
|
static void
|
|
gst_classifier_tensor_decoder_init (GstClassifierTensorDecoder * self)
|
|
{
|
|
self->threshold = DEFAULT_THRESHOLD;
|
|
self->labels_file = NULL;
|
|
self->softmax_res = NULL;
|
|
|
|
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), FALSE);
|
|
}
|
|
|
|
static void
|
|
gst_classifier_tensor_decoder_finalize (GObject * object)
|
|
{
|
|
GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (object);
|
|
|
|
g_free (self->labels_file);
|
|
G_OBJECT_CLASS (gst_classifier_tensor_decoder_parent_class)->finalize
|
|
(object);
|
|
}
|
|
|
|
static void
|
|
gst_classifier_tensor_decoder_set_property (GObject * object, guint prop_id,
|
|
const GValue * value, GParamSpec * pspec)
|
|
{
|
|
GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (object);
|
|
static GFileTest filetest = (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR);
|
|
|
|
switch (prop_id) {
|
|
case PROP_THRESHOLD:
|
|
self->threshold = g_value_get_float (value);
|
|
break;
|
|
case PROP_LABEL_FILE:
|
|
self->labels_file = g_strdup (g_value_get_string (value));
|
|
|
|
if (self->labels_file) {
|
|
if (!g_file_test (self->labels_file, filetest)) {
|
|
GST_ERROR_OBJECT (self, "Unable to load %s", self->labels_file);
|
|
g_free (g_steal_pointer (&self->labels_file));
|
|
}
|
|
} else {
|
|
GST_ERROR_OBJECT (self, "Invalid file");
|
|
}
|
|
break;
|
|
default:
|
|
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
|
break;
|
|
}
|
|
}
|
|
|
|
static void
|
|
gst_classifier_tensor_decoder_get_property (GObject * object, guint prop_id,
|
|
GValue * value, GParamSpec * pspec)
|
|
{
|
|
GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (object);
|
|
|
|
switch (prop_id) {
|
|
case PROP_THRESHOLD:
|
|
g_value_set_float (value, self->threshold);
|
|
break;
|
|
case PROP_LABEL_FILE:
|
|
g_value_set_string (value, self->labels_file);
|
|
break;
|
|
default:
|
|
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
|
break;
|
|
}
|
|
}
|
|
|
|
static guint
|
|
gst_classifier_tensor_decoder_load_labels (GstClassifierTensorDecoder * self)
|
|
{
|
|
gchar *content = NULL;
|
|
gchar **tokens = NULL;
|
|
gsize len;
|
|
GError *err = NULL;
|
|
GQuark val;
|
|
|
|
if (self->labels_file == NULL) {
|
|
GST_ERROR_OBJECT (self, "Missing label file");
|
|
return 0;
|
|
}
|
|
if (!g_file_get_contents (self->labels_file, &content, &len, &err)) {
|
|
GST_ERROR_OBJECT (self, "Could not load labels file %s: %s",
|
|
self->labels_file, err->message);
|
|
g_error_free (err);
|
|
return 0;
|
|
}
|
|
|
|
if (len == 0) {
|
|
GST_ERROR_OBJECT (self, "Labels file %s is empty", self->labels_file);
|
|
g_free (content);
|
|
return 0;
|
|
}
|
|
|
|
tokens = g_strsplit (content, "\n", 0);
|
|
g_free (content);
|
|
|
|
if (tokens[0] == NULL) {
|
|
GST_ERROR_OBJECT (self, "Labels file %s has no labels", self->labels_file);
|
|
g_free (content);
|
|
return 0;
|
|
}
|
|
|
|
self->class_quark = g_array_new (FALSE, FALSE, sizeof (GQuark));
|
|
|
|
for (int i = 0; tokens[i] != NULL && tokens[i][0] != '\0'; i++) {
|
|
val = g_quark_from_string (tokens[i]);
|
|
g_array_append_val (self->class_quark, val);
|
|
}
|
|
|
|
self->softmax_res = g_array_sized_new (FALSE, TRUE, sizeof (gfloat),
|
|
self->class_quark->len);
|
|
|
|
g_strfreev (tokens);
|
|
return self->class_quark->len;
|
|
}
|
|
|
|
static GstStateChangeReturn
|
|
gst_classifier_tensor_decoder_change_state (GstElement * element,
|
|
GstStateChange transition)
|
|
{
|
|
GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (element);
|
|
GstStateChangeReturn ret;
|
|
|
|
switch (transition) {
|
|
case GST_STATE_CHANGE_NULL_TO_READY:
|
|
if (!gst_classifier_tensor_decoder_load_labels (self)) {
|
|
return GST_STATE_CHANGE_FAILURE;
|
|
}
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition);
|
|
|
|
switch (transition) {
|
|
case GST_STATE_CHANGE_READY_TO_NULL:
|
|
g_array_free (self->class_quark, FALSE);
|
|
g_array_free (self->softmax_res, TRUE);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
static GstTensorMeta *
|
|
gst_classifier_tensor_decoder_get_tensor_meta (GstClassifierTensorDecoder *
|
|
self, GstBuffer * buf)
|
|
{
|
|
GstMeta *meta = NULL;
|
|
gpointer iter_state = NULL;
|
|
|
|
if (!gst_buffer_get_meta (buf, GST_TENSOR_META_API_TYPE)) {
|
|
GST_DEBUG_OBJECT (self,
|
|
"missing tensor meta from buffer %" GST_PTR_FORMAT, buf);
|
|
return NULL;
|
|
}
|
|
|
|
while ((meta = gst_buffer_iterate_meta_filtered (buf, &iter_state,
|
|
GST_TENSOR_META_API_TYPE))) {
|
|
GstTensorMeta *tensor_meta = (GstTensorMeta *) meta;
|
|
|
|
if (tensor_meta->num_tensors != 1)
|
|
continue;
|
|
|
|
gint index = gst_tensor_meta_get_index_from_id (tensor_meta,
|
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION));
|
|
|
|
if (index == -1)
|
|
continue;
|
|
|
|
return tensor_meta;
|
|
}
|
|
|
|
return NULL;
|
|
}
|
|
|
|
static GstFlowReturn
|
|
gst_classifier_tensor_decoder_decode (GstClassifierTensorDecoder * self,
|
|
GstBuffer * buf, GstAnalyticsRelationMeta * rmeta, GstTensorMeta * tmeta)
|
|
{
|
|
GstMapInfo map_info = GST_MAP_INFO_INIT;
|
|
gfloat max = 0.0;
|
|
gfloat *softmax_res = (gfloat *) self->softmax_res->data;
|
|
gsize len;
|
|
GQuark q, qmax;
|
|
gint max_idx = -1;
|
|
const GstTensor *tensor;
|
|
GstAnalyticsClsMtd cls_mtd;
|
|
const gsize DIMS[] = { 1, G_MAXSIZE };
|
|
|
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
|
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
|
|
if (tensor == NULL)
|
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
|
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
|
|
if (tensor == NULL)
|
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
|
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
|
|
if (tensor == NULL)
|
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
|
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
|
|
|
|
if (tensor == NULL) {
|
|
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
|
(NULL), ("Could not find classification tensor"));
|
|
return GST_FLOW_ERROR;
|
|
}
|
|
|
|
len = tensor->dims[tensor->num_dims - 1];
|
|
|
|
if (len != self->class_quark->len) {
|
|
GST_WARNING_OBJECT (self, "Labels file has size %zu, but the tensor has"
|
|
" %u entries, it is probably not the right labels file",
|
|
len, self->class_quark->len);
|
|
len = MIN (len, self->class_quark->len);
|
|
}
|
|
|
|
if (!gst_buffer_map (tensor->data, &map_info, GST_MAP_READ)) {
|
|
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
|
|
("Failed to map tensor data"));
|
|
return GST_FLOW_ERROR;
|
|
}
|
|
|
|
GST_TRACE_OBJECT (self, "Tensor shape dims %zu", tensor->num_dims);
|
|
|
|
if (gst_debug_category_get_threshold (GST_CAT_DEFAULT) >= GST_LEVEL_TRACE) {
|
|
for (gint i = 0; i < tensor->num_dims; i++) {
|
|
GST_TRACE_OBJECT (self, "Tensor dim %d: %zu", i, tensor->dims[i]);
|
|
}
|
|
}
|
|
|
|
switch (tensor->data_type) {
|
|
case GST_TENSOR_DATA_TYPE_FLOAT32:
|
|
softmax_f32 (len, (gfloat *) map_info.data, softmax_res);
|
|
break;
|
|
case GST_TENSOR_DATA_TYPE_UINT8:
|
|
softmax_u8 (len, (guint8 *) map_info.data, softmax_res);
|
|
break;
|
|
default:
|
|
g_return_val_if_reached (GST_FLOW_ERROR);
|
|
break;
|
|
}
|
|
gst_buffer_unmap (tensor->data, &map_info);
|
|
|
|
for (gint j = 0; j < len; j++) {
|
|
q = g_array_index (self->class_quark, GQuark, j);
|
|
if (softmax_res[j] > max) {
|
|
max = softmax_res[j];
|
|
max_idx = j;
|
|
qmax = q;
|
|
}
|
|
}
|
|
|
|
if (max_idx != -1) {
|
|
gst_analytics_relation_meta_add_one_cls_mtd (rmeta, max, qmax, &cls_mtd);
|
|
|
|
GST_LOG_OBJECT (self, "Max class is %d:%s with %f", max_idx,
|
|
g_quark_to_string (qmax), max);
|
|
}
|
|
|
|
return GST_FLOW_OK;
|
|
}
|
|
|
|
static GstFlowReturn
|
|
gst_classifier_tensor_decoder_transform_ip (GstBaseTransform * trans,
|
|
GstBuffer * buf)
|
|
{
|
|
GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (trans);
|
|
GstTensorMeta *tmeta;
|
|
GstAnalyticsRelationMeta *rmeta;
|
|
|
|
tmeta = gst_classifier_tensor_decoder_get_tensor_meta (self, buf);
|
|
if (tmeta != NULL) {
|
|
rmeta = gst_buffer_add_analytics_relation_meta (buf);
|
|
g_assert (rmeta != NULL);
|
|
} else {
|
|
GST_WARNING_OBJECT (trans, "missing tensor meta");
|
|
return TRUE;
|
|
}
|
|
|
|
return gst_classifier_tensor_decoder_decode (self, buf, rmeta, tmeta);
|
|
}
|