/* * GStreamer gstreamer-classifiertensordecoder * Copyright (C) 2025 Collabora Ltd. * @author: Daniel Morin * * gstclassifiertensordecoder.c * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Library General Public * License as published by the Free Software Foundation; either * version 2 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Library General Public License for more details. * * You should have received a copy of the GNU Library General Public * License along with this library; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301, USA. */ /** * SECTION:element-classifiertensordecoder.c * @short_description: Decode tensors from classification model using a common * tensor output format. * * * This element can parse per-buffer inference tensor meta data generated by * an upstream inference element. * * Tensor format must be: * Dims: [batch-size, class_count] * Datatype: float32 * * Tensor [M,N] * Batch 0 | Class 0 confidence level | ... | Class N confidence level | * ... * Batch M | Class 0 confidence level | ... | Class N confidence level | * * In-memory tensor format: * * |Batch 0, Class 0 confidence level | * |Batch 0, ... | * |Batch 0, Class N confidence level | * | ... | * |Batch M, Class 0 confidence level | * |Batch M, ... | * |Batch M, Class N confidence level | * * * ## Example launch command: * |[ * gst-launch-1.0 filesrc location=/onnx-models/images/bus.jpg \ * ! jpegdec \ * ! videoconvertscale add-borders=1 \ * ! onnxinference execution-provider=cpu \ * model-file=/onnx-models/models/mobilenet_v1.onnx \ * ! classifiertensordecoder labels-file=labels.txt ! fakesink \ * ]| This pipeline create an tensor-decoder for classification model * */ #ifdef HAVE_CONFI_H #include "config.h" #endif #include "gstclassifiertensordecoder.h" #include #include #include const gchar GST_MODEL_STD_IMAGE_CLASSIFICATION[] = "classification-generic-out"; GST_DEBUG_CATEGORY_STATIC (classifier_tensor_decoder_debug); #define GST_CAT_DEFAULT classifier_tensor_decoder_debug #define gst_classifier_tensor_decoder_parent_class parent_class GST_ELEMENT_REGISTER_DEFINE (classifier_tensor_decoder, "classifiertensordecoder", GST_RANK_PRIMARY, GST_TYPE_CLASSIFIER_TENSOR_DECODER); /* GstClassifierTensorDecoder properties */ enum { PROP_0, PROP_THRESHOLD, PROP_LABEL_FILE }; static const float DEFAULT_THRESHOLD = 0.7f; static GstStaticPadTemplate gst_classifier_tensor_decoder_src_template = GST_STATIC_PAD_TEMPLATE ("src", GST_PAD_SRC, GST_PAD_ALWAYS, GST_STATIC_CAPS_ANY); static GstStaticPadTemplate gst_classifier_tensor_decoder_sink_template = GST_STATIC_PAD_TEMPLATE ("sink", GST_PAD_SINK, GST_PAD_ALWAYS, GST_STATIC_CAPS_ANY); static void gst_classifier_tensor_decoder_set_property (GObject * object, guint prop_id, const GValue * value, GParamSpec * pspec); static void gst_classifier_tensor_decoder_get_property (GObject * object, guint prop_id, GValue * value, GParamSpec * pspec); static void gst_classifier_tensor_decoder_finalize (GObject * object); static GstFlowReturn gst_classifier_tensor_decoder_transform_ip (GstBaseTransform * trans, GstBuffer * buf); static GstStateChangeReturn gst_classifier_tensor_decoder_change_state (GstElement * element, GstStateChange transition); #define softmax(len, values, results, max_val) \ gsize i; \ gfloat sum = 0.0; \ gfloat value; \ g_return_if_fail (values != NULL); \ g_return_if_fail (result != NULL); \ \ /* Calculate exponential of every value */ \ for (i = 0; i < len; i++) { \ value = values[i] / max_val; \ result[i] = exp (value); \ sum += result[i]; \ } \ \ /* Complete softmax */ \ for (i = 0; i < len; i++) { \ result[i] = result[i] / sum; \ } static void softmax_u8 (gsize len, const guint8 * values, gfloat * result) { softmax (len, values, results, 255.0); } static void softmax_f32 (gsize len, const gfloat * values, gfloat * result) { softmax (len, values, results, 1.0); } G_DEFINE_TYPE (GstClassifierTensorDecoder, gst_classifier_tensor_decoder, GST_TYPE_BASE_TRANSFORM); static void gst_classifier_tensor_decoder_class_init (GstClassifierTensorDecoderClass * klass) { GObjectClass *gobject_class = (GObjectClass *) klass; GstElementClass *element_class = (GstElementClass *) klass; GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass; GST_DEBUG_CATEGORY_INIT (classifier_tensor_decoder_debug, "classifiertensordecoder", 0, "Tensor decoder for classification model with common output format"); gobject_class->set_property = gst_classifier_tensor_decoder_set_property; gobject_class->get_property = gst_classifier_tensor_decoder_get_property; gobject_class->finalize = gst_classifier_tensor_decoder_finalize; g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_THRESHOLD, g_param_spec_float ("class-confidence-threshold", "Class confidence threshold", "Classes with a confidence level inferior to this threshold " "will be excluded", 0.0, 1.0, DEFAULT_THRESHOLD, (GParamFlags) (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE, g_param_spec_string ("labels-file", "Class labels file", "Path to a file containing class label. COCO format", NULL, (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); element_class->change_state = gst_classifier_tensor_decoder_change_state; gst_element_class_set_static_metadata (element_class, "classifiertensordecoder", "Tensordecoder", "Decode tensors output from classification model using common format.\n" "\tTensor format must be: \n" "\t\tDims: [batch-size, class_count]\n" "\t\tDatatype: float32 \n" "\n" "\t\tTensor [M,N]\n" "\t\t\tBatch 0 | Class 0 confidence level | ... | Class N-1 confidence level |\n" "\t\t\t...\n" "\t\t\tBatch M-1 | Class 0 confidence level | ... | Class N-1 confidence level |\n" "\t\t\n" "\tIn-memory tensor format:\n" "\n" "\t\t|Batch 0, Class 0 confidence level |\n" "\t\t|Batch 0, ... |\n" "\t\t|Batch 0, Class N-1 confidence level |\n" "\t\t| ... |\n" "\t\t|Batch M-1, Class 0 confidence level |\n" "\t\t|Batch M-1, ... |\n" "\t\t|Batch M-1, Class N-1 confidence level |\n" "\n" " model", "Daniel Morin "); gst_element_class_add_pad_template (element_class, gst_static_pad_template_get (&gst_classifier_tensor_decoder_src_template)); gst_element_class_add_pad_template (element_class, gst_static_pad_template_get (&gst_classifier_tensor_decoder_sink_template)); basetransform_class->transform_ip = GST_DEBUG_FUNCPTR (gst_classifier_tensor_decoder_transform_ip); } static void gst_classifier_tensor_decoder_init (GstClassifierTensorDecoder * self) { self->threshold = DEFAULT_THRESHOLD; self->labels_file = NULL; self->softmax_res = NULL; gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), FALSE); } static void gst_classifier_tensor_decoder_finalize (GObject * object) { GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (object); g_free (self->labels_file); G_OBJECT_CLASS (gst_classifier_tensor_decoder_parent_class)->finalize (object); } static void gst_classifier_tensor_decoder_set_property (GObject * object, guint prop_id, const GValue * value, GParamSpec * pspec) { GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (object); static GFileTest filetest = (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR); switch (prop_id) { case PROP_THRESHOLD: self->threshold = g_value_get_float (value); break; case PROP_LABEL_FILE: self->labels_file = g_strdup (g_value_get_string (value)); if (self->labels_file) { if (!g_file_test (self->labels_file, filetest)) { GST_ERROR_OBJECT (self, "Unable to load %s", self->labels_file); g_free (g_steal_pointer (&self->labels_file)); } } else { GST_ERROR_OBJECT (self, "Invalid file"); } break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); break; } } static void gst_classifier_tensor_decoder_get_property (GObject * object, guint prop_id, GValue * value, GParamSpec * pspec) { GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (object); switch (prop_id) { case PROP_THRESHOLD: g_value_set_float (value, self->threshold); break; case PROP_LABEL_FILE: g_value_set_string (value, self->labels_file); break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); break; } } static guint gst_classifier_tensor_decoder_load_labels (GstClassifierTensorDecoder * self) { gchar *content = NULL; gchar **tokens = NULL; gsize len; GError *err = NULL; GQuark val; if (self->labels_file == NULL) { GST_ERROR_OBJECT (self, "Missing label file"); return 0; } if (!g_file_get_contents (self->labels_file, &content, &len, &err)) { GST_ERROR_OBJECT (self, "Could not load labels file %s: %s", self->labels_file, err->message); g_error_free (err); return 0; } if (len == 0) { GST_ERROR_OBJECT (self, "Labels file %s is empty", self->labels_file); g_free (content); return 0; } tokens = g_strsplit (content, "\n", 0); g_free (content); if (tokens[0] == NULL) { GST_ERROR_OBJECT (self, "Labels file %s has no labels", self->labels_file); g_free (content); return 0; } self->class_quark = g_array_new (FALSE, FALSE, sizeof (GQuark)); for (int i = 0; tokens[i] != NULL && tokens[i][0] != '\0'; i++) { val = g_quark_from_string (tokens[i]); g_array_append_val (self->class_quark, val); } self->softmax_res = g_array_sized_new (FALSE, TRUE, sizeof (gfloat), self->class_quark->len); g_strfreev (tokens); return self->class_quark->len; } static GstStateChangeReturn gst_classifier_tensor_decoder_change_state (GstElement * element, GstStateChange transition) { GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (element); GstStateChangeReturn ret; switch (transition) { case GST_STATE_CHANGE_NULL_TO_READY: if (!gst_classifier_tensor_decoder_load_labels (self)) { return GST_STATE_CHANGE_FAILURE; } break; default: break; } ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition); switch (transition) { case GST_STATE_CHANGE_READY_TO_NULL: g_array_free (self->class_quark, FALSE); g_array_free (self->softmax_res, TRUE); break; default: break; } return ret; } static GstTensorMeta * gst_classifier_tensor_decoder_get_tensor_meta (GstClassifierTensorDecoder * self, GstBuffer * buf) { GstMeta *meta = NULL; gpointer iter_state = NULL; if (!gst_buffer_get_meta (buf, GST_TENSOR_META_API_TYPE)) { GST_DEBUG_OBJECT (self, "missing tensor meta from buffer %" GST_PTR_FORMAT, buf); return NULL; } while ((meta = gst_buffer_iterate_meta_filtered (buf, &iter_state, GST_TENSOR_META_API_TYPE))) { GstTensorMeta *tensor_meta = (GstTensorMeta *) meta; if (tensor_meta->num_tensors != 1) continue; gint index = gst_tensor_meta_get_index_from_id (tensor_meta, g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION)); if (index == -1) continue; return tensor_meta; } return NULL; } static GstFlowReturn gst_classifier_tensor_decoder_decode (GstClassifierTensorDecoder * self, GstBuffer * buf, GstAnalyticsRelationMeta * rmeta, GstTensorMeta * tmeta) { GstMapInfo map_info = GST_MAP_INFO_INIT; gfloat max = 0.0; gfloat *softmax_res = (gfloat *) self->softmax_res->data; gsize len; GQuark q, qmax; gint max_idx = -1; const GstTensor *tensor; GstAnalyticsClsMtd cls_mtd; const gsize DIMS[] = { 1, G_MAXSIZE }; tensor = gst_tensor_meta_get_typed_tensor (tmeta, g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION), GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL); if (tensor == NULL) tensor = gst_tensor_meta_get_typed_tensor (tmeta, g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION), GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS); if (tensor == NULL) tensor = gst_tensor_meta_get_typed_tensor (tmeta, g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION), GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL); if (tensor == NULL) tensor = gst_tensor_meta_get_typed_tensor (tmeta, g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION), GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS); if (tensor == NULL) { GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED, (NULL), ("Could not find classification tensor")); return GST_FLOW_ERROR; } len = tensor->dims[tensor->num_dims - 1]; if (len != self->class_quark->len) { GST_WARNING_OBJECT (self, "Labels file has size %zu, but the tensor has" " %u entries, it is probably not the right labels file", len, self->class_quark->len); len = MIN (len, self->class_quark->len); } if (!gst_buffer_map (tensor->data, &map_info, GST_MAP_READ)) { GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Failed to map tensor data")); return GST_FLOW_ERROR; } GST_TRACE_OBJECT (self, "Tensor shape dims %zu", tensor->num_dims); if (gst_debug_category_get_threshold (GST_CAT_DEFAULT) >= GST_LEVEL_TRACE) { for (gint i = 0; i < tensor->num_dims; i++) { GST_TRACE_OBJECT (self, "Tensor dim %d: %zu", i, tensor->dims[i]); } } switch (tensor->data_type) { case GST_TENSOR_DATA_TYPE_FLOAT32: softmax_f32 (len, (gfloat *) map_info.data, softmax_res); break; case GST_TENSOR_DATA_TYPE_UINT8: softmax_u8 (len, (guint8 *) map_info.data, softmax_res); break; default: g_return_val_if_reached (GST_FLOW_ERROR); break; } gst_buffer_unmap (tensor->data, &map_info); for (gint j = 0; j < len; j++) { q = g_array_index (self->class_quark, GQuark, j); if (softmax_res[j] > max) { max = softmax_res[j]; max_idx = j; qmax = q; } } if (max_idx != -1) { gst_analytics_relation_meta_add_one_cls_mtd (rmeta, max, qmax, &cls_mtd); GST_LOG_OBJECT (self, "Max class is %d:%s with %f", max_idx, g_quark_to_string (qmax), max); } return GST_FLOW_OK; } static GstFlowReturn gst_classifier_tensor_decoder_transform_ip (GstBaseTransform * trans, GstBuffer * buf) { GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (trans); GstTensorMeta *tmeta; GstAnalyticsRelationMeta *rmeta; tmeta = gst_classifier_tensor_decoder_get_tensor_meta (self, buf); if (tmeta != NULL) { rmeta = gst_buffer_add_analytics_relation_meta (buf); g_assert (rmeta != NULL); } else { GST_WARNING_OBJECT (trans, "missing tensor meta"); return TRUE; } return gst_classifier_tensor_decoder_decode (self, buf, rmeta, tmeta); }