tflite: Add TensorFlow Lite element

A new element wrapping the LiteRT (aka TensorFlow Lite) inference engine.
It currently supports only CPU.

Co-authored-by: Daniel Morin <daniel.morin@collabora.com>
Co-authored-by: Denis Shimizu <denis.shimizu@collabora.com>
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/8523>
This commit is contained in:
Denis Shimizu 2024-03-09 13:42:22 -03:00 committed by GStreamer Marge Bot
parent 5c188d90c0
commit 05782229ee
12 changed files with 1733 additions and 0 deletions

View File

@ -69,6 +69,7 @@ subdir('svtav1')
subdir('svthevcenc')
subdir('svtjpegxs')
subdir('teletextdec')
subdir('tflite')
subdir('ttml')
subdir('voaacenc')
subdir('voamrwbenc')

View File

@ -0,0 +1,9 @@
# GStreamer elements for TensorFlow Lite #
Given a TensorFlow Lite model, this element executes the inference to produce and add `GstTensorMeta` metadata to the buffer to be consumed by a tensor decoder
Requires the TensorFlow Lite library. Tested with TensorFlow r2.18
# To build TensorFlow Lite:
See detailed info on: [https://www.tensorflow.org/lite/guide/build_cmake](https://www.tensorflow.org/lite/guide/build_cmake)

View File

@ -0,0 +1,63 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstring>
#include <cstdlib>
#include "vsi_npu_custom_op.h"
namespace tflite {
namespace ops {
namespace custom {
namespace vsi_npu {
static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteVsiNpuParams* data = reinterpret_cast<TfLiteVsiNpuParams*>(
malloc(sizeof(TfLiteVsiNpuParams) + sizeof(char) * length));
data->length = length;
data->binary = reinterpret_cast<char*>(data) + sizeof(TfLiteVsiNpuParams);
memcpy(reinterpret_cast<char*>(data->binary), buffer, length);
return reinterpret_cast<void*>(data);
}
static void Free(TfLiteContext* context, void* buffer) {
auto* data = reinterpret_cast<TfLiteVsiNpuParams*>(buffer);
delete data;
}
static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* data =
reinterpret_cast<TfLiteVsiNpuParams*>(node->user_data);
data->input_count = node->inputs->size;
data->output_cout = node->outputs->size;
return kTfLiteOk;
}
static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
} // namespace vsi_npu
TfLiteRegistration* Register_VSI_NPU_PRECOMPILED() {
static TfLiteRegistration r = {
vsi_npu::Init, vsi_npu::Free,
vsi_npu::Prepare,vsi_npu::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,48 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_VSI_NPU_CUSTOM_OP_H_
#define TENSORFLOW_LITE_DELEGATES_VSI_NPU_CUSTOM_OP_H_
#include "tensorflow/lite/c/common.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
static const char kNbgCustomOp[] = "vsi-npu";
typedef struct {
size_t length;
size_t input_count;
size_t output_cout;
char* binary;
} TfLiteVsiNpuParams;
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_VSI_NPU_PRECOMPILED(void);
} // namespace custom
} // namespace ops
} // namespace tflite
#endif //TENSORFLOW_LITE_DELEGATES_VSI_NPU_CUSTOM_OP_H_

View File

@ -0,0 +1,41 @@
/*
* GStreamer gstreamer-ml
* Copyright (C) 2021 Collabora Ltd
*
* gstml.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_ML_H__
#define __GST_ML_H__
/**
* GstMlInputImageFormat:
*
* @GST_ML_INPUT_IMAGE_FORMAT_HWC Height Width Channel (a.k.a. interleaved) format
* @GST_ML_INPUT_IMAGE_FORMAT_CHW Channel Height Width (a.k.a. planar) format
*
* Since: 1.20
*/
typedef enum {
GST_ML_INPUT_IMAGE_FORMAT_HWC,
GST_ML_INPUT_IMAGE_FORMAT_CHW,
} GstMlInputImageFormat;
#endif

View File

@ -0,0 +1,39 @@
/*
* GStreamer gstreamer-tflite
* Copyright (C) 2024 Collabora Ltd
*
* gsttflite.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "gsttfliteinference.h"
static gboolean
plugin_init (GstPlugin *plugin)
{
return GST_ELEMENT_REGISTER (tflite_inference, plugin);
}
GST_PLUGIN_DEFINE (GST_VERSION_MAJOR,
GST_VERSION_MINOR,
tflite,
"TFLITE neural network plugin",
plugin_init, VERSION, GST_LICENSE, GST_PACKAGE_NAME, GST_PACKAGE_ORIGIN);

View File

@ -0,0 +1,998 @@
/*
* GStreamer
* Copyright (C) 2024 Collabora Ltd.
*
* gsttfliteinference.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
/**
* SECTION:element-tfliteinference
* @short_description: Run TFLITE inference model on video buffers
*
* This element can apply an TFLITE model to video buffers. It attaches
* the tensor output to the buffer as a @ref GstTensorMeta.
*
* To install TFLITE on your system, follow the instructions in the
* README.md in with this plugin.
*
* ## Example launch command:
*
* GST_DEBUG=ssdobjectdetector:5 \
* gst-launch-1.0 filesrc location=tflite-models/images/bus.jpg ! \
* jpegdec ! videoconvert ! tfliteinference model-file=tflite-models/models/ssd_mobilenet_v1_coco.tflite ! \
* ssdobjectdetector label-file=tflite-models/labels/COCO_classes.txt ! videoconvert ! imagefreeze ! autovideosink
*
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "tensorflow/lite/c/c_api.h"
#include <gst/gst.h>
#include <gst/video/video.h>
#include "gsttfliteinference.h"
#include "modelinfo.h"
#include <tensorflow/lite/c/common.h>
#define DEFAULT_MODEL_FILE ""
#define DEFAULT_THREADS 0
/*
* GstTFliteInference:
*
* @model_file model file
* @tflite_client opaque pointer to TFLITE client
* @tflite_disabled true if inference is disabled
* @video_info @ref GstVideoInfo of sink caps
*/
typedef struct _GstTFliteInferencePrivate
{
GstBaseTransform basetransform;
gchar *model_file;
gsize numberOfThreads;
gchar *vxdelegate;
gboolean planar;
GPtrArray *tensor_templates;
TfLiteInterpreter *interpreter;
TfLiteInterpreterOptions *interpreter_options;
TfLiteModel *model;
gboolean tflite_disabled;
GstVideoInfo video_info;
guint8 *dest;
GstCaps *model_caps;
gint channels;
gdouble *means;
gdouble *stddevs;
} GstTFliteInferencePrivate;
GST_DEBUG_CATEGORY (tflite_inference_debug);
#define GST_CAT_DEFAULT tflite_inference_debug
GST_ELEMENT_REGISTER_DEFINE (tflite_inference, "tfliteinference",
GST_RANK_NONE, GST_TYPE_TFLITE_INFERENCE);
/* GstTFliteInference properties */
enum
{
PROP_0,
PROP_MODEL_FILE,
PROP_THREADS,
};
#define VIDEO_CAPS GST_VIDEO_CAPS_MAKE ("{ RGB, RGBA, BGR, BGRA }")
static GstStaticPadTemplate gst_tflite_inference_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
GST_PAD_SRC,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (VIDEO_CAPS)
);
static GstStaticPadTemplate gst_tflite_inference_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
GST_PAD_SINK,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (VIDEO_CAPS)
);
static gboolean gst_tflite_inference_start (GstBaseTransform * trans);
static gboolean gst_tflite_inference_stop (GstBaseTransform * trans);
static void gst_tflite_inference_set_property (GObject * object,
guint prop_id, const GValue * value, GParamSpec * pspec);
static void gst_tflite_inference_get_property (GObject * object,
guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_tflite_inference_finalize (GObject * object);
static GstFlowReturn gst_tflite_inference_transform_ip (GstBaseTransform *
trans, GstBuffer * buf);
static gboolean gst_tflite_inference_process (GstBaseTransform * trans,
GstBuffer * buf);
static GstCaps *gst_tflite_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
static gboolean
gst_tflite_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
G_DEFINE_TYPE_WITH_PRIVATE (GstTFliteInference, gst_tflite_inference,
GST_TYPE_BASE_TRANSFORM);
static void
gst_tflite_inference_class_init (GstTFliteInferenceClass * klass)
{
GObjectClass *gobject_class = (GObjectClass *) klass;
GstElementClass *element_class = (GstElementClass *) klass;
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
GST_DEBUG_CATEGORY_INIT (tflite_inference_debug, "tfliteinference",
0, "tflite_inference");
gobject_class->set_property = gst_tflite_inference_set_property;
gobject_class->get_property = gst_tflite_inference_get_property;
gobject_class->finalize = gst_tflite_inference_finalize;
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_MODEL_FILE,
g_param_spec_string ("model-file",
"TFLITE model file", "TFLITE model file", DEFAULT_MODEL_FILE,
(GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_THREADS,
g_param_spec_int ("threads",
"Number of Threads",
"Set the number of threads to be used by the TFLITE inference (-1 for auto)",
-1, G_MAXINT, DEFAULT_THREADS,
(GParamFlags) (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "tfliteinference",
"Filter/Effect",
"Apply neural network to video frames and create tensor output",
"Denis Shimizu <denis.shimizu@collabora.com>, "
"Aaron Boxer <aaron.boxer@collabora.com>,"
"Daniel Morin <daniel.morin@collabora.com>");
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_tflite_inference_sink_template));
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_tflite_inference_src_template));
basetransform_class->transform_ip =
GST_DEBUG_FUNCPTR (gst_tflite_inference_transform_ip);
basetransform_class->transform_caps =
GST_DEBUG_FUNCPTR (gst_tflite_inference_transform_caps);
basetransform_class->set_caps =
GST_DEBUG_FUNCPTR (gst_tflite_inference_set_caps);
basetransform_class->start = GST_DEBUG_FUNCPTR (gst_tflite_inference_start);
basetransform_class->stop = GST_DEBUG_FUNCPTR (gst_tflite_inference_stop);
}
static gboolean
gst_tflite_inference_has_session (GstTFliteInference * self)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
return priv->interpreter != NULL;
}
static void
gst_tflite_inference_init (GstTFliteInference * self)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
priv->numberOfThreads = DEFAULT_THREADS;
priv->tensor_templates = g_ptr_array_new_with_free_func ((GDestroyNotify)
gst_tensor_free);
priv->tflite_disabled = TRUE;
}
static void
gst_tflite_inference_finalize (GObject * object)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (object);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
g_free (priv->model_file);
g_ptr_array_unref (priv->tensor_templates);
G_OBJECT_CLASS (gst_tflite_inference_parent_class)->finalize (object);
}
static void
gst_tflite_inference_set_property (GObject * object, guint prop_id,
const GValue * value, GParamSpec * pspec)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (object);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
const gchar *filename;
switch (prop_id) {
case PROP_MODEL_FILE:
filename = g_value_get_string (value);
if (filename
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
if (priv->model_file)
g_free (priv->model_file);
priv->model_file = g_strdup (filename);
priv->tflite_disabled = FALSE;
} else {
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
}
break;
case PROP_THREADS:
priv->numberOfThreads = g_value_get_int (value);
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static void
gst_tflite_inference_get_property (GObject * object, guint prop_id,
GValue * value, GParamSpec * pspec)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (object);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
switch (prop_id) {
case PROP_MODEL_FILE:
g_value_set_string (value, priv->model_file);
break;
case PROP_THREADS:
g_value_set_int (value, priv->numberOfThreads);
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static GstTensorDataType
gst_tflite_convert_data_type (TfLiteType type)
{
switch (type) {
case kTfLiteFloat32:
return GST_TENSOR_DATA_TYPE_FLOAT32;
case kTfLiteInt32:
return GST_TENSOR_DATA_TYPE_INT32;
case kTfLiteUInt8:
return GST_TENSOR_DATA_TYPE_UINT8;
case kTfLiteInt64:
return GST_TENSOR_DATA_TYPE_INT64;
case kTfLiteInt16:
return GST_TENSOR_DATA_TYPE_INT16;
case kTfLiteInt8:
return GST_TENSOR_DATA_TYPE_INT8;
case kTfLiteFloat16:
return GST_TENSOR_DATA_TYPE_FLOAT16;
case kTfLiteFloat64:
return GST_TENSOR_DATA_TYPE_FLOAT64;
case kTfLiteUInt64:
return GST_TENSOR_DATA_TYPE_UINT64;
case kTfLiteUInt32:
return GST_TENSOR_DATA_TYPE_UINT32;
case kTfLiteUInt16:
return GST_TENSOR_DATA_TYPE_UINT16;
case kTfLiteInt4:
return GST_TENSOR_DATA_TYPE_INT4;
#ifdef TFLITE_HAS_BFLOAT16
case kTfLiteBFloat16:
return GST_TENSOR_DATA_TYPE_BFLOAT16;
#endif
default:
GST_FIXME ("GstTensorDataType currently does not have a mapping \
for this type.");
g_assert_not_reached ();
}
}
static gboolean
convert_tensor_info (const TfLiteTensor * tflite_tensor,
const gchar ** tname, GstTensorDataType * data_type,
gsize * dims_count, gsize ** out_dims)
{
gsize j;
gsize *dims;
if (tname)
*tname = TfLiteTensorName (tflite_tensor);
*dims_count = TfLiteTensorNumDims (tflite_tensor);
if (*dims_count == 0)
return FALSE;
dims = *out_dims = (gsize *) g_malloc0_n (*dims_count, sizeof (gsize));
if (tflite_tensor->dims_signature && tflite_tensor->dims_signature->size) {
for (j = 0; j < *dims_count; j++) {
if (tflite_tensor->dims_signature->data[j] < 0)
dims[j] = G_MAXSIZE;
else
dims[j] = tflite_tensor->dims_signature->data[j];
}
} else {
for (j = 0; j < *dims_count; j++)
dims[j] = TfLiteTensorDim (tflite_tensor, j);
}
*data_type = gst_tflite_convert_data_type (TfLiteTensorType (tflite_tensor));
return TRUE;
}
static gchar *
build_dims_str (gsize dims_count, gsize * dims)
{
GString *dims_gstr = g_string_new ("");
gsize j;
if (dims_count == 0)
goto done;
if (dims[0] == G_MAXSIZE)
g_string_append (dims_gstr, "-1");
else
g_string_append_printf (dims_gstr, "%zu", dims[0]);
for (j = 1; j < dims_count; j++)
if (dims[j] == G_MAXSIZE)
g_string_append (dims_gstr, ",-1");
else
g_string_append_printf (dims_gstr, ",%zu", dims[j]);
done:
return g_string_free (dims_gstr, FALSE);
}
static gboolean
_get_input_params (GstTFliteInference * self, GstTensorDataType * data_type,
gint * width, gint * height, const gchar ** gst_format,
gint * channels, gboolean * planar)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
const TfLiteTensor *input_tensor;
gint i_size = TfLiteInterpreterGetInputTensorCount (priv->interpreter);
gsize dims_count;
gsize *dims = NULL;
if (i_size != 1) {
GST_ERROR_OBJECT (self, "Currently only support model with a single"
" input tensor, but model has %d", i_size);
goto reject;
}
input_tensor = TfLiteInterpreterGetInputTensor (priv->interpreter, 0);
if (!convert_tensor_info (input_tensor, NULL, data_type, &dims_count, &dims)) {
GST_ERROR_OBJECT (self, "Input tensor has no dimensions, rejecting");
goto reject;
}
if (dims_count < 2 || dims_count > 4) {
GST_ERROR_OBJECT (self,
"Don't know how to interpret tensors with %zu dimensions", dims_count);
goto reject;
}
*planar = FALSE;
switch (dims_count) {
case 2:
*gst_format = "GRAY8";
*height = dims[0];
*width = dims[1];
break;
case 3:
if (dims[0] == 1 || dims[0] == 3) {
*channels = dims[0];
if (dims[0] == 1) {
*gst_format = "GRAY8";
} else {
*gst_format = "RGBP";
*planar = TRUE;
}
*height = dims[1];
*width = dims[2];
} else if (dims[2] == 1 || dims[2] == 3) {
*channels = dims[2];
if (dims[2] == 1)
*gst_format = "GRAY";
else
*gst_format = "RGB";
*height = dims[0];
*width = dims[1];
} else {
GST_ERROR_OBJECT (self, "Don't know how to interpret dims");
goto reject;
}
break;
case 4:
/* Assuming dims[0] is a batch */
if (dims[1] == 1 || dims[1] == 3) {
*channels = dims[1];
*planar = TRUE;
*height = dims[2];
*width = dims[3];
} else if (dims[3] == 1 || dims[3] == 3) {
*channels = dims[3];
*height = dims[1];
*width = dims[2];
} else {
GST_ERROR_OBJECT (self, "Don't know how to interpret dims");
goto reject;
}
if (*channels == 1) {
*gst_format = "GRAY8";
*planar = FALSE;
} else if (*channels == 3) {
if (*planar)
*gst_format = "RGBP";
else
*gst_format = "RGB";
} else {
g_assert_not_reached ();
}
break;
}
g_free (dims);
return TRUE;
reject:
g_free (dims);
return FALSE;
}
static gboolean
gst_tflite_inference_start (GstBaseTransform * trans)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
gboolean ret = FALSE;
ModelInfo *modelinfo = NULL;
gint i_size, o_size;
GstTFliteInferenceClass *klass = GST_TFLITE_INFERENCE_GET_CLASS (self);
GST_OBJECT_LOCK (self);
if (gst_tflite_inference_has_session (self)) {
ret = TRUE;
goto done;
}
if (priv->model_file == NULL) {
GST_ERROR_OBJECT (self, "model-file property not set");
goto done;
}
priv->model = TfLiteModelCreateFromFile (priv->model_file);
if (!priv->model) {
GST_ERROR_OBJECT (self, "Failed to mmap model %s", priv->model_file);
goto error;
}
GST_DEBUG_OBJECT (self, "Loaded model %s", priv->model_file);
priv->interpreter_options = TfLiteInterpreterOptionsCreate ();
if (priv->numberOfThreads != 0) {
TfLiteInterpreterOptionsSetNumThreads (priv->interpreter_options,
priv->numberOfThreads);
}
priv->interpreter = TfLiteInterpreterCreate (priv->model,
priv->interpreter_options);
if (!priv->interpreter) {
GST_ERROR_OBJECT (self, "Failed to construct interpreter");
goto error;
}
modelinfo = modelinfo_load (priv->model_file);
if (!modelinfo) {
GST_ERROR_OBJECT (self, "Can't find modelinfo for %s", priv->model_file);
goto error;
}
i_size = TfLiteInterpreterGetInputTensorCount (priv->interpreter);
if (i_size != 1) {
GST_ERROR_OBJECT (self, "Currently only support model with a single"
" input tensor, but model has %d", i_size);
goto error;
}
{
const guint i = 0;
const TfLiteTensor *tflite_tensor =
TfLiteInterpreterGetInputTensor (priv->interpreter, i);
const gchar *tname;
GstTensorDataType data_type;
gsize dims_count;
gsize *dims;
gchar *tensor_name = NULL;
gint width = 0, height = 0;
const gchar *gst_format = NULL;
guint num_means, num_stddevs;
if (!_get_input_params (self, &data_type, &width, &height, &gst_format,
&priv->channels, &priv->planar)) {
GST_ERROR_OBJECT (self, "Failed to get parameters");
goto error;
}
if (!convert_tensor_info (tflite_tensor, &tname, &data_type,
&dims_count, &dims)) {
GST_ERROR_OBJECT (self, "Rejecting input_tensor[%d]:%s with no dims",
i, tname);
goto error;
}
tensor_name = modelinfo_find_tensor_name (modelinfo,
MODELINFO_DIRECTION_INPUT, i, tname, data_type, dims_count, dims);
if (tensor_name == NULL) {
gchar *dims_str = build_dims_str (dims_count, dims);
GST_DEBUG_OBJECT (self,
"Model info file doesn't contain info for input_tensor[%u]:%s matching the"
" type %s and dims %s", i, tname,
gst_tensor_data_type_get_name (data_type), dims_str);
g_free (dims);
g_free (dims_str);
} else {
num_means = modelinfo_get_normalization_means (modelinfo,
tensor_name, priv->channels, &priv->means);
if (num_means != priv->channels) {
priv->means = g_renew (gdouble, priv->means, priv->channels);
for (guint j = 1; j < priv->channels; j++)
priv->means[j] = priv->means[0];
}
num_stddevs = modelinfo_get_normalization_stddevs (modelinfo,
tensor_name, priv->channels, &priv->stddevs);
if (num_stddevs != priv->channels) {
priv->stddevs = g_renew (gdouble, priv->stddevs, priv->channels);
for (guint j = 1; j < priv->channels; j++)
priv->stddevs[j] = priv->stddevs[0];
}
}
gst_clear_caps (&priv->model_caps);
priv->model_caps = gst_caps_new_empty_simple ("video/x-raw");
if (width && height)
gst_caps_set_simple (priv->model_caps, "width", G_TYPE_INT, width,
"height", G_TYPE_INT, height, NULL);
if (data_type == GST_TENSOR_DATA_TYPE_UINT8 && gst_format &&
priv->means == NULL && priv->stddevs == NULL)
gst_caps_set_simple (priv->model_caps, "format", G_TYPE_STRING,
gst_format, NULL);
g_free (tensor_name);
}
if (TfLiteInterpreterAllocateTensors (priv->interpreter) != kTfLiteOk) {
GST_ERROR_OBJECT (self, "Failed to allocate tensors");
goto error;
}
o_size = TfLiteInterpreterGetOutputTensorCount (priv->interpreter);
for (guint i = 0; i < o_size; i++) {
const TfLiteTensor *tflite_tensor =
TfLiteInterpreterGetOutputTensor (priv->interpreter, i);
const gchar *tname;
GstTensorDataType data_type;
gsize dims_count;
gsize *dims;
gchar *tensor_name = NULL;
if (!convert_tensor_info (tflite_tensor, &tname, &data_type,
&dims_count, &dims)) {
GST_WARNING_OBJECT (self, "Skipping output_tensor[%d]:%s with no dims",
i, tname);
continue;
}
tensor_name = modelinfo_find_tensor_name (modelinfo,
MODELINFO_DIRECTION_OUTPUT, i, tname, data_type, dims_count, dims);
gchar *dims_str = build_dims_str (dims_count, dims);
if (tensor_name == NULL) {
GST_ERROR_OBJECT (self,
"Model info file doesn't contain info for output_tensor[%u]:%s matching the"
" type %s and dims %s", i, tname,
gst_tensor_data_type_get_name (data_type), dims_str);
g_free (dims);
g_free (dims_str);
g_ptr_array_set_size (priv->tensor_templates, 0);
goto error;
}
GstTensor *t = gst_tensor_alloc (dims_count);
gchar *id = modelinfo_get_id (modelinfo, tensor_name);
GST_DEBUG_OBJECT (self, "Mapping output_tensor[%d]:%s of type %s and"
" dims %s to id %s", i, tname,
gst_tensor_data_type_get_name (data_type), dims_str, id);
g_free (id);
g_free (dims_str);
t->id = modelinfo_get_quark_id (modelinfo, tensor_name);
t->layout = GST_TENSOR_LAYOUT_CONTIGUOUS;
t->data_type = data_type;
t->dims_order = GST_TENSOR_DIM_ORDER_ROW_MAJOR;
memcpy (t->dims, dims, sizeof (gsize) * t->num_dims);
g_free (dims);
g_ptr_array_add (priv->tensor_templates, t);
g_free (tensor_name);
}
TfLiteTensor *itensor = TfLiteInterpreterGetInputTensor (priv->interpreter,
0);
if (TfLiteTensorType (itensor) == kTfLiteFloat32) {
GST_DEBUG_OBJECT (self, "Floating point Tensorflow Lite Model");
}
ret = TRUE;
done:
if (modelinfo)
modelinfo_free (modelinfo);
GST_OBJECT_UNLOCK (self);
return ret;
error:
GST_ERROR_OBJECT (self,
"Unable to create TFLITE session. Inference is disabled.");
GST_BASE_TRANSFORM_GET_CLASS (self)->stop (trans);
goto done;
}
static gboolean
gst_tflite_inference_stop (GstBaseTransform * trans)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
if (priv->interpreter)
TfLiteInterpreterDelete (priv->interpreter);
priv->interpreter = NULL;
if (priv->interpreter_options)
TfLiteInterpreterOptionsDelete (priv->interpreter_options);
priv->interpreter_options = NULL;
if (priv->model)
TfLiteModelDelete (priv->model);
priv->model = NULL;
gst_clear_caps (&priv->model_caps);
g_ptr_array_set_size (priv->tensor_templates, 0);
return TRUE;
}
static GstCaps *
gst_tflite_inference_transform_caps (GstBaseTransform * trans,
GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
GstCaps *other_caps;
if (priv->model_caps == NULL) {
other_caps = gst_caps_ref (caps);
goto done;
}
GST_DEBUG_OBJECT (self, "Applying caps restrictions: %" GST_PTR_FORMAT,
priv->model_caps);
other_caps = gst_caps_intersect_full (caps, priv->model_caps,
GST_CAPS_INTERSECT_FIRST);
done:
if (filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps,
GST_CAPS_INTERSECT_FIRST);
gst_caps_replace (&other_caps, tmp);
gst_caps_unref (tmp);
}
return other_caps;
}
static gboolean
gst_tflite_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
if (!gst_video_info_from_caps (&priv->video_info, incaps)) {
GST_ERROR_OBJECT (self, "Failed to parse caps");
return FALSE;
}
return TRUE;
}
static GstFlowReturn
gst_tflite_inference_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
{
if (!gst_base_transform_is_passthrough (trans)
&& !gst_tflite_inference_process (trans, buf)) {
GST_ELEMENT_ERROR (trans, STREAM, FAILED,
(NULL), ("TFLITE inference failed"));
return GST_FLOW_ERROR;
}
return GST_FLOW_OK;
}
#define _convert_image_remove_alpha(Type, dst, srcPtr, \
srcSamplesPerPixel, stride, means, stddevs) \
G_STMT_START { \
size_t destIndex = 0; \
Type tmp; \
\
if (!priv->planar) { \
for (int32_t j = 0; j < dstHeight; ++j) { \
for (int32_t i = 0; i < dstWidth; ++i) { \
for (int32_t k = 0; k < dstChannels; ++k) { \
tmp = *srcPtr[k]; \
tmp += means[k]; \
dst[destIndex++] = (Type)(tmp / stddevs[k]); \
srcPtr[k] += srcSamplesPerPixel; \
} \
} \
/* correct for stride */ \
for (uint32_t k = 0; k < 3; ++k) \
srcPtr[k] += stride - srcSamplesPerPixel * dstWidth; \
} \
} else { \
size_t frameSize = dstWidth * dstHeight; \
Type *destPtr[3] = { dst, dst + frameSize, dst + 2 * frameSize }; \
for (int32_t j = 0; j < dstHeight; ++j) { \
for (int32_t i = 0; i < dstWidth; ++i) { \
for (int32_t k = 0; k < dstChannels; ++k) { \
tmp = *srcPtr[k]; \
tmp += means[k]; \
destPtr[k][destIndex] = (Type)(tmp / stddevs[k]); \
srcPtr[k] += srcSamplesPerPixel; \
} \
destIndex++; \
} \
/* correct for stride */ \
for (uint32_t k = 0; k < 3; ++k) \
srcPtr[k] += stride - srcSamplesPerPixel * dstWidth; \
} \
} \
} \
G_STMT_END;
static void
gst_tflite_inference_convert_image_remove_alpha_u8 (GstTFliteInference * self,
guint8 * dst, gint dstWidth, gint dstHeight, gint dstChannels,
guint8 ** srcPtr, guint8 srcSamplesPerPixel,
guint32 stride, const gdouble * means, const gdouble * stddevs)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
static const gdouble zeros[] = { 0, 0, 0, 0 };
static const gdouble ones[] = { 1.0, 1.0, 1.0, 1.0 };
if (means == NULL)
means = zeros;
if (stddevs == NULL)
stddevs = ones;
_convert_image_remove_alpha (guint8, dst, srcPtr, srcSamplesPerPixel,
stride, means, stddevs);
}
static void
gst_tflite_inference_convert_image_remove_alpha_f32 (GstTFliteInference * self,
gfloat * dst, gint dstWidth, gint dstHeight, gint dstChannels,
guint8 ** srcPtr, guint8 srcSamplesPerPixel,
guint32 stride, const gdouble * means, const gdouble * stddevs)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
static const gdouble zeros[] = { 0, 0, 0, 0 };
static const gdouble two_five_fives[] = { 255.0, 255.0, 255.0, 255.0 };
if (means == NULL)
means = zeros;
if (stddevs == NULL)
stddevs = two_five_fives;
_convert_image_remove_alpha (gfloat, dst, srcPtr, srcSamplesPerPixel,
stride, means, stddevs);
}
static gboolean
gst_tflite_inference_process (GstBaseTransform * trans, GstBuffer * buf)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
GstMapInfo info;
guint8 *srcPtr[3];
gsize srcSamplesPerPixel = 3;
GstTensorDataType datatype;
if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
// <==
srcPtr[0] = info.data;
srcPtr[1] = info.data + 1;
srcPtr[2] = info.data + 2;
switch (priv->video_info.finfo->format) {
case GST_VIDEO_FORMAT_RGBA:
srcSamplesPerPixel = 4;
break;
case GST_VIDEO_FORMAT_BGRA:
srcSamplesPerPixel = 4;
srcPtr[0] = info.data + 2;
srcPtr[1] = info.data + 1;
srcPtr[2] = info.data + 0;
break;
case GST_VIDEO_FORMAT_ARGB:
srcSamplesPerPixel = 4;
srcPtr[0] = info.data + 1;
srcPtr[1] = info.data + 2;
srcPtr[2] = info.data + 3;
break;
case GST_VIDEO_FORMAT_ABGR:
srcSamplesPerPixel = 4;
srcPtr[0] = info.data + 3;
srcPtr[1] = info.data + 2;
srcPtr[2] = info.data + 1;
break;
case GST_VIDEO_FORMAT_BGR:
srcPtr[0] = info.data + 2;
srcPtr[1] = info.data + 1;
srcPtr[2] = info.data + 0;
break;
default:
break;
}
TfLiteTensor *tensor = TfLiteInterpreterGetInputTensor (priv->interpreter,
0);
guint width = GST_VIDEO_INFO_WIDTH (&priv->video_info);
guint height = GST_VIDEO_INFO_HEIGHT (&priv->video_info);
guint32 stride = priv->video_info.stride[0];
guint channels;
if (GST_VIDEO_INFO_IS_GRAY (&priv->video_info)) {
channels = 1;
} else if (GST_VIDEO_INFO_IS_RGB (&priv->video_info)) {
channels = 3;
} else {
g_assert_not_reached ();
}
datatype = gst_tflite_convert_data_type (TfLiteTensorType (tensor));
switch (datatype) {
case GST_TENSOR_DATA_TYPE_UINT8:{
uint8_t *dest = (uint8_t *) TfLiteTensorData (tensor);
if (dest == NULL)
return false;
gst_tflite_inference_convert_image_remove_alpha_u8 (self,
dest, width, height, channels, srcPtr,
srcSamplesPerPixel, stride, priv->means, priv->stddevs);
break;
}
case GST_TENSOR_DATA_TYPE_FLOAT32:{
float *dest = (float *) TfLiteTensorData (tensor);
if (dest == NULL)
return false;
gst_tflite_inference_convert_image_remove_alpha_f32 (self, dest,
width, height, channels, srcPtr,
srcSamplesPerPixel, stride, priv->means, priv->stddevs);
break;
}
default:{
GST_ERROR_OBJECT (self, "Data type not handled");
return false;
}
break;
}
/* Run inference */
if (TfLiteInterpreterInvoke (priv->interpreter) != kTfLiteOk) {
GST_ERROR_OBJECT (self, "Failed to invoke tflite!");
return false;
}
gsize num_tensors =
TfLiteInterpreterGetOutputTensorCount (priv->interpreter);
g_assert (num_tensors == priv->tensor_templates->len);
GstTensor **tensors =
(GstTensor **) g_malloc0_n (num_tensors, sizeof (gpointer));
for (size_t i = 0; i < num_tensors; i++) {
const TfLiteTensor *output_tensor =
TfLiteInterpreterGetOutputTensor (priv->interpreter, i);
tensors[i] = gst_tensor_alloc (TfLiteTensorNumDims (output_tensor));
memcpy (tensors[i], g_ptr_array_index (priv->tensor_templates, i),
sizeof (GstTensor));
tensors[i]->num_dims = TfLiteTensorNumDims (output_tensor);
for (gsize j = 0; j < tensors[i]->num_dims; j++)
tensors[i]->dims[j] = TfLiteTensorDim (output_tensor, j);;
tensors[i]->data =
gst_buffer_new_allocate (NULL, TfLiteTensorByteSize (output_tensor),
NULL);
gst_buffer_fill (tensors[i]->data, 0, TfLiteTensorData (output_tensor),
TfLiteTensorByteSize (output_tensor));
}
GstTensorMeta *tmeta = gst_buffer_add_tensor_meta (buf);
gst_tensor_meta_set (tmeta, num_tensors, tensors);
if (!tmeta)
return FALSE;
GST_TRACE_OBJECT (trans, "Num tensors: %zu", tmeta->num_tensors);
gst_buffer_unmap (buf, &info);
}
return TRUE;
}

View File

@ -0,0 +1,43 @@
/*
* GStreamer gstreamer-tfliteinference
* Copyright (C) 2024 Collabora Ltd
*
* gsttfliteinference.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_TFLITE_INFERENCE_H__
#define __GST_TFLITE_INFERENCE_H__
#include <gst/gst.h>
#include <gst/base/base.h>
G_BEGIN_DECLS
#define GST_TYPE_TFLITE_INFERENCE (gst_tflite_inference_get_type())
G_DECLARE_DERIVABLE_TYPE (GstTFliteInference, gst_tflite_inference, GST,
TFLITE_INFERENCE, GstBaseTransform)
GST_ELEMENT_REGISTER_DECLARE (tflite_inference)
struct _GstTFliteInferenceClass
{
GstBaseTransformClass basetransform;
};
G_END_DECLS
#endif /* __GST_TFLITE_INFERENCE_H__ */

View File

@ -0,0 +1,50 @@
tflite_sources = [
'gsttflite.c',
'gsttfliteinference.c',
'modelinfo.c',
]
tflite_headers = [
'gstfliteinference.h'
]
doc_sources = []
foreach s: tflite_sources + tflite_headers
doc_sources += meson.current_source_dir() / s
endforeach
plugin_sources += {
'tflite': pathsep.join(doc_sources)
}
if get_option('tflite').disabled()
subdir_done()
endif
tensorflow_lite_dep = cc.find_library('tensorflowlite_c',
required: get_option('tflite'))
tensorflow_lite_header_found = cc.has_header('tensorflow/lite/c/c_api.h',
dependencies: tensorflow_lite_dep,
required: get_option('tflite'))
if tensorflow_lite_dep.found() and tensorflow_lite_header_found
tflite_c_args = []
if cc.has_header_symbol('tensorflow/lite/c/c_api.h', 'kTfLiteBFloat16',
dependencies: tensorflow_lite_dep)
tflite_c_args += ['-DTFLITE_HAS_BFLOAT16']
endif
gsttflite = library('gsttflite',
tflite_sources,
c_args : gst_plugins_bad_args + tflite_c_args,
include_directories : [configinc, libsinc],
dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep,
tensorflow_lite_dep,libm, gio_dep],
install : true,
install_dir : plugins_install_dir,
)
plugins += [gsttflite]
endif

View File

@ -0,0 +1,374 @@
/*
* GStreamer
* Copyright (C) 2025 Collabora Ltd.
* @author: Olivier Crete <olivier.crete@collabora.com>
*
* modeinfo.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "modelinfo.h"
/**
* SECTION: ModelInfo
*
* The ".modelinfo" files describe the additional metadata for
* a given serialized model file such as a `.tflite`, `.onnx` or `.pte` files.
*
* The ModelInfo files are ini-style. Each section is matched to a
* particular input or output tensor.
*
* The title of the section should ideally match the name of the tensor
* in the model file.
*
* The fields used to match the modelinfo to the model are:
* * `\[title\]`: The name of the tensor, must be unique
* * `dims`: The dimensions as a comma-separated list of ints. -1 matches a dynamic dimension and is a wildcard
* * `dir`: Either "input" or "output"
* * `type`: The data type match #GstTensorDataType, one of:
* * `int4`
* * `int8`
* * `int16`
* * `int32`
* * `int64`
* * `uint4`
* * `uint8`
* * `uint16`
* * `uint32`
* * `uint64`
* * `float16`
* * `float32`
* * `float64`
* * `bfloat16`
*
* Based on these fields, the following metadata is applied to output tensors:
* * `id`: The tensor ID so othr elements can identity it, ideally registered in the [Tensor ID Registry](https://github.com/collabora/tensor-id-registry/blob/main/tensor-id-register.md).
*
* Those fields are applied to input tensors for normalization:
* * `mean`: a double or a comma separated list of floats, one per channel.
* * `stddev`: a double or a comma separated list of floats, one per channel
*
* Those are applied with the formula `(val - mean) / stddev`. Those
* are applied based on a range of [0, 255]. If the input is not in
* the range of [0, 255], the values will be converted before applyign
* them. A mean of 127" means 127 for a `uint8` input or 0 for
* `int8` and 0.5 for `float` inputs.
*
* Other fields are ignored for now.
*
* The API is meant to be used by plugins
*
* Since: 1.28
*/
GST_DEBUG_CATEGORY (analytics_modelinfo_debug);
#define GST_CAT_DEFAULT analytics_modelinfo_debug
static gboolean
key_file_string_matches (GKeyFile * keyfile, const gchar * group,
const gchar * key, const gchar * value)
{
gchar *kf_value = g_key_file_get_string (keyfile, group, key, NULL);
gboolean matches = !g_strcmp0 (kf_value, value);
g_free (kf_value);
return matches;
}
gchar *
modelinfo_get_id (ModelInfo * modelinfo, const gchar * tensor_name)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
return g_key_file_get_string (kf, tensor_name, "id", NULL);
}
GQuark
modelinfo_get_quark_id (ModelInfo * modelinfo, const gchar * tensor_name)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
GQuark q = 0;
gchar *id = g_key_file_get_string (kf, tensor_name, "id", NULL);
if (id)
q = g_quark_from_string (id);
g_free (id);
return q;
}
static gboolean
modelinfo_check_direction (GKeyFile * kf,
const gchar * tensor_name, ModelInfoTensorDirection dir)
{
gchar *value;
gboolean ret = FALSE;
if (dir == MODELINFO_DIRECTION_UNKNOWN)
return TRUE;
value = g_key_file_get_string (kf, tensor_name, "dir", NULL);
if (!value)
return TRUE;
if (dir == MODELINFO_DIRECTION_INPUT)
ret = g_str_equal (value, "input");
if (dir == MODELINFO_DIRECTION_OUTPUT)
ret = g_str_equal (value, "output");
g_free (value);
return ret;
}
static gboolean
modelinfo_validate_internal (GKeyFile * kf, const gchar * tensor_name,
ModelInfoTensorDirection dir, GstTensorDataType data_type, gsize num_dims,
const gsize * dims, gboolean accept_no_dims)
{
gsize kf_dims_length = 0;
gint *kf_dims;
gsize i;
gboolean ret = FALSE;
if (!key_file_string_matches (kf, tensor_name, "type",
gst_tensor_data_type_get_name (data_type)))
return FALSE;
if (!modelinfo_check_direction (kf, tensor_name, dir))
return FALSE;
if (!g_key_file_has_key (kf, tensor_name, "dims", NULL))
return accept_no_dims;
kf_dims = g_key_file_get_integer_list (kf, tensor_name, "dims",
&kf_dims_length, NULL);
if (kf_dims == NULL) {
GST_ERROR ("Invalid model info file, dims in %s is no in the"
" right format", tensor_name);
return FALSE;
}
if (kf_dims_length != num_dims)
goto done;
for (i = 0; i < kf_dims_length; i++) {
/* If the keyfile contains dims < 0, then its a wildcard,
* accept anything */
if (kf_dims[i] < 0)
continue;
/* Dimensions of size "-1" means dynamic, but we didn't accept a wildcard,
* reject it */
if (dims[i] == G_MAXSIZE)
goto done;
if (kf_dims[i] != dims[i])
goto done;
}
ret = TRUE;
done:
g_free (kf_dims);
return ret;
}
static gboolean
modelinfo_validate (ModelInfo * modelinfo, const gchar * tensor_name,
ModelInfoTensorDirection dir, GstTensorDataType data_type, gsize num_dims,
const gsize * dims)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
return modelinfo_validate_internal (kf, tensor_name, dir, data_type,
num_dims, dims, TRUE);
}
static gboolean
modelinfo_has_tensor_name (ModelInfo * modelinfo, const char *tensor_name)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
return g_key_file_has_group (kf, tensor_name);
}
static gchar *
modelinfo_find_tensor_name_by_index (ModelInfo * modelinfo,
ModelInfoTensorDirection dir, gsize index)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
gchar **groups;
gsize i, j;
gchar *tensor_name = NULL;
groups = g_key_file_get_groups (kf, NULL);
for (i = 0, j = 0; groups[i]; i++) {
if (!modelinfo_check_direction (kf, groups[i], dir))
continue;
if (index == j++) {
tensor_name = g_strdup (groups[i]);
break;
}
j++;
}
g_strfreev (groups);
return tensor_name;
}
static gchar *
modelinfo_find_tensor_name_by_dims (ModelInfo * modelinfo,
ModelInfoTensorDirection dir, GstTensorDataType data_type,
gsize num_dims, const gsize * dims)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
gchar **groups;
gsize i;
gchar *tensor_name = NULL;
groups = g_key_file_get_groups (kf, NULL);
for (i = 0; groups[i]; i++) {
if (modelinfo_validate_internal (kf, groups[i], dir, data_type,
num_dims, dims, FALSE)) {
tensor_name = g_strdup (groups[i]);
break;
}
}
g_strfreev (groups);
return tensor_name;
}
ModelInfo *
modelinfo_load (const gchar * model_filename)
{
GKeyFile *kf = g_key_file_new ();
gchar *filename;
gboolean ret;
gchar *last_dot;
g_key_file_set_list_separator (kf, ',');
GST_DEBUG_CATEGORY_INIT (analytics_modelinfo_debug, "modelinfo",
0, "analytics model info");
filename = g_strconcat (model_filename, ".modelinfo", NULL);
ret = g_key_file_load_from_file (kf, filename, G_KEY_FILE_NONE, NULL);
g_free (filename);
if (ret)
return (ModelInfo *) kf;
last_dot = g_utf8_strrchr (model_filename, -1, '.');
if (last_dot && !g_utf8_strchr (last_dot, -1, '/')) {
gchar *tmp = g_strndup (model_filename, last_dot - model_filename);
filename = g_strconcat (tmp, ".modelinfo", NULL);
g_free (tmp);
ret = g_key_file_load_from_file (kf, filename, G_KEY_FILE_NONE, NULL);
g_free (filename);
if (ret)
return (ModelInfo *) kf;
}
g_key_file_free (kf);
return NULL;
}
void
modelinfo_free (ModelInfo * modelinfo)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
g_key_file_free (kf);
}
gchar *
modelinfo_find_tensor_name (ModelInfo * modelinfo,
ModelInfoTensorDirection dir, gsize index, const gchar * in_tensor_name,
GstTensorDataType data_type, gsize num_dims, const gsize * dims)
{
gchar *tensor_name = NULL;
if (in_tensor_name && modelinfo_has_tensor_name (modelinfo, in_tensor_name)) {
if (modelinfo_validate (modelinfo, in_tensor_name, dir, data_type,
num_dims, dims)) {
return g_strdup (in_tensor_name);
}
}
tensor_name = modelinfo_find_tensor_name_by_index (modelinfo, dir, index);
if (tensor_name) {
if (modelinfo_validate (modelinfo, tensor_name, dir, data_type,
num_dims, dims)) {
return tensor_name;
}
g_free (tensor_name);
}
return modelinfo_find_tensor_name_by_dims (modelinfo, dir, data_type,
num_dims, dims);
}
static gsize
modelinfo_get_doubles (ModelInfo * modelinfo, const gchar * tensor_name,
const gchar * param_name, gsize num_channels, gdouble ** out_doubles)
{
GKeyFile *kf = (GKeyFile *) modelinfo;
gdouble *doubles;
gsize length;
doubles = g_key_file_get_double_list (kf, tensor_name, param_name, &length,
NULL);
if (doubles == NULL)
return 0;
if (length != 1 && length != num_channels) {
g_free (doubles);
return 0;
}
*out_doubles = doubles;
return length;
}
gsize
modelinfo_get_normalization_means (ModelInfo * modelinfo,
const gchar * tensor_name, gsize num_channels, gdouble ** means)
{
return modelinfo_get_doubles (modelinfo, tensor_name, "mean",
num_channels, means);
}
gsize
modelinfo_get_normalization_stddevs (ModelInfo * modelinfo,
const gchar * tensor_name, gsize num_channels, gdouble ** means)
{
return modelinfo_get_doubles (modelinfo, tensor_name, "stddev",
num_channels, means);
}

View File

@ -0,0 +1,66 @@
/*
* GStreamer
* Copyright (C) 2025 Collabora Ltd.
* @author: Olivier Crete <olivier.crete@collabora.com>
*
* modeinfo.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include <glib.h>
#include <gst/analytics/analytics.h>
#pragma once
G_BEGIN_DECLS
typedef enum {
MODELINFO_DIRECTION_UNKNOWN,
MODELINFO_DIRECTION_INPUT,
MODELINFO_DIRECTION_OUTPUT,
} ModelInfoTensorDirection;
typedef struct _ModelInfo ModelInfo;
ModelInfo *
modelinfo_load (const gchar *model_filename);
gchar *
modelinfo_find_tensor_name (ModelInfo * modelinfo,
ModelInfoTensorDirection dir, gsize index, const gchar *in_tensor_name,
GstTensorDataType data_type, gsize num_dims, const gsize * dims);
gchar *
modelinfo_get_id (ModelInfo *modelinfo, const gchar * tensor_name);
GQuark
modelinfo_get_quark_id (ModelInfo *modelinfo, const gchar * tensor_name);
gsize
modelinfo_get_normalization_means (ModelInfo * modelinfo,
const gchar *tensor_name, gsize num_channels, gdouble ** mean);
gsize
modelinfo_get_normalization_stddevs (ModelInfo * modelinfo,
const gchar *tensor_name, gsize num_channels, gdouble ** stddev);
void
modelinfo_free (ModelInfo *model_info);
G_END_DECLS

View File

@ -176,6 +176,7 @@ option('svtav1', type : 'feature', value : 'auto', description : 'Scalable Video
option('svthevcenc', type : 'feature', value : 'auto', description : 'Scalable Video Technology for HEVC encoder plugin')
option('svtjpegxs', type : 'feature', value : 'auto', description : 'Scalable Video Technology for JPEG-XS plugin')
option('teletext', type : 'feature', value : 'auto', description : 'Teletext plugin')
option('tflite', type : 'feature', value : 'auto', description : 'TensorFlow Lite (LiteRT) plugin')
option('tinyalsa', type : 'feature', value : 'auto', description : 'TinyALSA plugin')
option('transcode', type : 'feature', value : 'auto', description : 'Transcode plugin')
option('ttml', type : 'feature', value : 'auto', description : 'TTML subtitle parser and renderer plugin')