
/*
* GStreamer
* Copyright (C) 2024 Collabora Ltd.
*
* gsttfliteinference.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
/**
* SECTION:element-tfliteinference
* @short_description: Run TFLITE inference model on video buffers
*
 * This element can apply a TFLITE model to video buffers. It attaches
 * the tensor output to the buffer as a #GstTensorMeta.
 *
 * To install TFLITE on your system, follow the instructions in the
 * README.md shipped with this plugin.
*
* ## Example launch command:
*
* GST_DEBUG=ssdobjectdetector:5 \
* gst-launch-1.0 filesrc location=tflite-models/images/bus.jpg ! \
* jpegdec ! videoconvert ! tfliteinference model-file=tflite-models/models/ssd_mobilenet_v1_coco.tflite ! \
* ssdobjectdetector label-file=tflite-models/labels/COCO_classes.txt ! videoconvert ! imagefreeze ! autovideosink
*
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include <gst/gst.h>
#include <gst/video/video.h>
#include "gsttfliteinference.h"
#include "modelinfo.h"
#include <tensorflow/lite/c/common.h>
#define DEFAULT_MODEL_FILE ""
#define DEFAULT_THREADS 0
/*
 * GstTFliteInference:
 *
 * @model_file model file
 * @interpreter TFLITE interpreter running the model
 * @tflite_disabled true if inference is disabled
 * @video_info #GstVideoInfo of sink caps
 */
typedef struct _GstTFliteInferencePrivate
{
GstBaseTransform basetransform;
gchar *model_file;
  gint numberOfThreads;
gchar *vxdelegate;
gboolean planar;
GPtrArray *tensor_templates;
TfLiteInterpreter *interpreter;
TfLiteInterpreterOptions *interpreter_options;
TfLiteModel *model;
gboolean tflite_disabled;
GstVideoInfo video_info;
guint8 *dest;
GstCaps *model_caps;
gint channels;
gdouble *means;
gdouble *stddevs;
} GstTFliteInferencePrivate;
GST_DEBUG_CATEGORY (tflite_inference_debug);
#define GST_CAT_DEFAULT tflite_inference_debug
GST_ELEMENT_REGISTER_DEFINE (tflite_inference, "tfliteinference",
GST_RANK_NONE, GST_TYPE_TFLITE_INFERENCE);
/* GstTFliteInference properties */
enum
{
PROP_0,
PROP_MODEL_FILE,
PROP_THREADS,
};
#define VIDEO_CAPS GST_VIDEO_CAPS_MAKE ("{ RGB, RGBA, BGR, BGRA }")
static GstStaticPadTemplate gst_tflite_inference_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
GST_PAD_SRC,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (VIDEO_CAPS)
);
static GstStaticPadTemplate gst_tflite_inference_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
GST_PAD_SINK,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (VIDEO_CAPS)
);
static gboolean gst_tflite_inference_start (GstBaseTransform * trans);
static gboolean gst_tflite_inference_stop (GstBaseTransform * trans);
static void gst_tflite_inference_set_property (GObject * object,
guint prop_id, const GValue * value, GParamSpec * pspec);
static void gst_tflite_inference_get_property (GObject * object,
guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_tflite_inference_finalize (GObject * object);
static GstFlowReturn gst_tflite_inference_transform_ip (GstBaseTransform *
trans, GstBuffer * buf);
static gboolean gst_tflite_inference_process (GstBaseTransform * trans,
GstBuffer * buf);
static GstCaps *gst_tflite_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
static gboolean
gst_tflite_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
G_DEFINE_TYPE_WITH_PRIVATE (GstTFliteInference, gst_tflite_inference,
GST_TYPE_BASE_TRANSFORM);
static void
gst_tflite_inference_class_init (GstTFliteInferenceClass * klass)
{
GObjectClass *gobject_class = (GObjectClass *) klass;
GstElementClass *element_class = (GstElementClass *) klass;
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
GST_DEBUG_CATEGORY_INIT (tflite_inference_debug, "tfliteinference",
0, "tflite_inference");
gobject_class->set_property = gst_tflite_inference_set_property;
gobject_class->get_property = gst_tflite_inference_get_property;
gobject_class->finalize = gst_tflite_inference_finalize;
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_MODEL_FILE,
g_param_spec_string ("model-file",
"TFLITE model file", "TFLITE model file", DEFAULT_MODEL_FILE,
(GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_THREADS,
g_param_spec_int ("threads",
"Number of Threads",
"Set the number of threads to be used by the TFLITE inference (-1 for auto)",
-1, G_MAXINT, DEFAULT_THREADS,
(GParamFlags) (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "tfliteinference",
"Filter/Effect",
"Apply neural network to video frames and create tensor output",
"Denis Shimizu <denis.shimizu@collabora.com>, "
"Aaron Boxer <aaron.boxer@collabora.com>,"
"Daniel Morin <daniel.morin@collabora.com>");
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_tflite_inference_sink_template));
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_tflite_inference_src_template));
basetransform_class->transform_ip =
GST_DEBUG_FUNCPTR (gst_tflite_inference_transform_ip);
basetransform_class->transform_caps =
GST_DEBUG_FUNCPTR (gst_tflite_inference_transform_caps);
basetransform_class->set_caps =
GST_DEBUG_FUNCPTR (gst_tflite_inference_set_caps);
basetransform_class->start = GST_DEBUG_FUNCPTR (gst_tflite_inference_start);
basetransform_class->stop = GST_DEBUG_FUNCPTR (gst_tflite_inference_stop);
}
static gboolean
gst_tflite_inference_has_session (GstTFliteInference * self)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
return priv->interpreter != NULL;
}
static void
gst_tflite_inference_init (GstTFliteInference * self)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
priv->numberOfThreads = DEFAULT_THREADS;
priv->tensor_templates = g_ptr_array_new_with_free_func ((GDestroyNotify)
gst_tensor_free);
priv->tflite_disabled = TRUE;
}
static void
gst_tflite_inference_finalize (GObject * object)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (object);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
g_free (priv->model_file);
g_ptr_array_unref (priv->tensor_templates);
G_OBJECT_CLASS (gst_tflite_inference_parent_class)->finalize (object);
}
static void
gst_tflite_inference_set_property (GObject * object, guint prop_id,
const GValue * value, GParamSpec * pspec)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (object);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
const gchar *filename;
switch (prop_id) {
case PROP_MODEL_FILE:
filename = g_value_get_string (value);
if (filename
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
        g_free (priv->model_file);
        priv->model_file = g_strdup (filename);
priv->tflite_disabled = FALSE;
} else {
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
}
break;
case PROP_THREADS:
priv->numberOfThreads = g_value_get_int (value);
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static void
gst_tflite_inference_get_property (GObject * object, guint prop_id,
GValue * value, GParamSpec * pspec)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (object);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
switch (prop_id) {
case PROP_MODEL_FILE:
g_value_set_string (value, priv->model_file);
break;
case PROP_THREADS:
g_value_set_int (value, priv->numberOfThreads);
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static GstTensorDataType
gst_tflite_convert_data_type (TfLiteType type)
{
switch (type) {
case kTfLiteFloat32:
return GST_TENSOR_DATA_TYPE_FLOAT32;
case kTfLiteInt32:
return GST_TENSOR_DATA_TYPE_INT32;
case kTfLiteUInt8:
return GST_TENSOR_DATA_TYPE_UINT8;
case kTfLiteInt64:
return GST_TENSOR_DATA_TYPE_INT64;
case kTfLiteInt16:
return GST_TENSOR_DATA_TYPE_INT16;
case kTfLiteInt8:
return GST_TENSOR_DATA_TYPE_INT8;
case kTfLiteFloat16:
return GST_TENSOR_DATA_TYPE_FLOAT16;
case kTfLiteFloat64:
return GST_TENSOR_DATA_TYPE_FLOAT64;
case kTfLiteUInt64:
return GST_TENSOR_DATA_TYPE_UINT64;
case kTfLiteUInt32:
return GST_TENSOR_DATA_TYPE_UINT32;
case kTfLiteUInt16:
return GST_TENSOR_DATA_TYPE_UINT16;
case kTfLiteInt4:
return GST_TENSOR_DATA_TYPE_INT4;
#ifdef TFLITE_HAS_BFLOAT16
case kTfLiteBFloat16:
return GST_TENSOR_DATA_TYPE_BFLOAT16;
#endif
default:
GST_FIXME ("GstTensorDataType currently does not have a mapping \
for this type.");
g_assert_not_reached ();
}
}
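/* Extract the name, data type and dimensions of a TfLiteTensor. Dynamic
 * dimensions (negative entries in the dims signature) are reported as
 * G_MAXSIZE. Returns FALSE if the tensor has no dimensions. */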
static gboolean
convert_tensor_info (const TfLiteTensor * tflite_tensor,
const gchar ** tname, GstTensorDataType * data_type,
gsize * dims_count, gsize ** out_dims)
{
gsize j;
gsize *dims;
if (tname)
*tname = TfLiteTensorName (tflite_tensor);
*dims_count = TfLiteTensorNumDims (tflite_tensor);
if (*dims_count == 0)
return FALSE;
dims = *out_dims = (gsize *) g_malloc0_n (*dims_count, sizeof (gsize));
if (tflite_tensor->dims_signature && tflite_tensor->dims_signature->size) {
for (j = 0; j < *dims_count; j++) {
if (tflite_tensor->dims_signature->data[j] < 0)
dims[j] = G_MAXSIZE;
else
dims[j] = tflite_tensor->dims_signature->data[j];
}
} else {
for (j = 0; j < *dims_count; j++)
dims[j] = TfLiteTensorDim (tflite_tensor, j);
}
*data_type = gst_tflite_convert_data_type (TfLiteTensorType (tflite_tensor));
return TRUE;
}
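/* Render tensor dimensions as a comma-separated string for debug output,
 * printing dynamic dimensions (G_MAXSIZE) as -1. */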
static gchar *
build_dims_str (gsize dims_count, gsize * dims)
{
GString *dims_gstr = g_string_new ("");
gsize j;
if (dims_count == 0)
goto done;
if (dims[0] == G_MAXSIZE)
g_string_append (dims_gstr, "-1");
else
g_string_append_printf (dims_gstr, "%zu", dims[0]);
for (j = 1; j < dims_count; j++)
if (dims[j] == G_MAXSIZE)
g_string_append (dims_gstr, ",-1");
else
g_string_append_printf (dims_gstr, ",%zu", dims[j]);
done:
return g_string_free (dims_gstr, FALSE);
}
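/* Derive the video format the model expects from its input tensor shape:
 * 2 dims are read as height x width (GRAY8), 3 dims as CHW (planar) or HWC
 * (interleaved), 4 dims as NCHW or NHWC with dims[0] assumed to be the
 * batch. Only 1 (gray) or 3 (RGB) channels are supported. */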
static gboolean
_get_input_params (GstTFliteInference * self, GstTensorDataType * data_type,
gint * width, gint * height, const gchar ** gst_format,
gint * channels, gboolean * planar)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
const TfLiteTensor *input_tensor;
gint i_size = TfLiteInterpreterGetInputTensorCount (priv->interpreter);
gsize dims_count;
gsize *dims = NULL;
if (i_size != 1) {
GST_ERROR_OBJECT (self, "Currently only support model with a single"
" input tensor, but model has %d", i_size);
goto reject;
}
input_tensor = TfLiteInterpreterGetInputTensor (priv->interpreter, 0);
if (!convert_tensor_info (input_tensor, NULL, data_type, &dims_count, &dims)) {
GST_ERROR_OBJECT (self, "Input tensor has no dimensions, rejecting");
goto reject;
}
if (dims_count < 2 || dims_count > 4) {
GST_ERROR_OBJECT (self,
"Don't know how to interpret tensors with %zu dimensions", dims_count);
goto reject;
}
*planar = FALSE;
switch (dims_count) {
case 2:
*gst_format = "GRAY8";
*height = dims[0];
*width = dims[1];
break;
case 3:
if (dims[0] == 1 || dims[0] == 3) {
*channels = dims[0];
if (dims[0] == 1) {
*gst_format = "GRAY8";
} else {
*gst_format = "RGBP";
*planar = TRUE;
}
*height = dims[1];
*width = dims[2];
} else if (dims[2] == 1 || dims[2] == 3) {
*channels = dims[2];
if (dims[2] == 1)
*gst_format = "GRAY";
else
*gst_format = "RGB";
*height = dims[0];
*width = dims[1];
} else {
GST_ERROR_OBJECT (self, "Don't know how to interpret dims");
goto reject;
}
break;
case 4:
/* Assuming dims[0] is a batch */
if (dims[1] == 1 || dims[1] == 3) {
*channels = dims[1];
*planar = TRUE;
*height = dims[2];
*width = dims[3];
} else if (dims[3] == 1 || dims[3] == 3) {
*channels = dims[3];
*height = dims[1];
*width = dims[2];
} else {
GST_ERROR_OBJECT (self, "Don't know how to interpret dims");
goto reject;
}
if (*channels == 1) {
*gst_format = "GRAY8";
*planar = FALSE;
} else if (*channels == 3) {
if (*planar)
*gst_format = "RGBP";
else
*gst_format = "RGB";
} else {
g_assert_not_reached ();
}
break;
}
g_free (dims);
return TRUE;
reject:
g_free (dims);
return FALSE;
}
static gboolean
gst_tflite_inference_start (GstBaseTransform * trans)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
gboolean ret = FALSE;
ModelInfo *modelinfo = NULL;
gint i_size, o_size;
GstTFliteInferenceClass *klass = GST_TFLITE_INFERENCE_GET_CLASS (self);
GST_OBJECT_LOCK (self);
if (gst_tflite_inference_has_session (self)) {
ret = TRUE;
goto done;
}
if (priv->model_file == NULL) {
GST_ERROR_OBJECT (self, "model-file property not set");
goto done;
}
priv->model = TfLiteModelCreateFromFile (priv->model_file);
if (!priv->model) {
GST_ERROR_OBJECT (self, "Failed to mmap model %s", priv->model_file);
goto error;
}
GST_DEBUG_OBJECT (self, "Loaded model %s", priv->model_file);
priv->interpreter_options = TfLiteInterpreterOptionsCreate ();
if (priv->numberOfThreads != 0) {
TfLiteInterpreterOptionsSetNumThreads (priv->interpreter_options,
priv->numberOfThreads);
}
if (klass->update_options)
if (!klass->update_options (self, priv->interpreter_options))
goto error;
priv->interpreter = TfLiteInterpreterCreate (priv->model,
priv->interpreter_options);
if (!priv->interpreter) {
GST_ERROR_OBJECT (self, "Failed to construct interpreter");
goto error;
}
modelinfo = modelinfo_load (priv->model_file);
if (!modelinfo) {
GST_ERROR_OBJECT (self, "Can't find modelinfo for %s", priv->model_file);
goto error;
}
i_size = TfLiteInterpreterGetInputTensorCount (priv->interpreter);
if (i_size != 1) {
GST_ERROR_OBJECT (self, "Currently only support model with a single"
" input tensor, but model has %d", i_size);
goto error;
}
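  /* Configure the single input tensor: derive the caps the model expects
   * from the tensor shape and look up normalization constants in the
   * modelinfo file. */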
{
const guint i = 0;
const TfLiteTensor *tflite_tensor =
TfLiteInterpreterGetInputTensor (priv->interpreter, i);
const gchar *tname;
GstTensorDataType data_type;
gsize dims_count;
gsize *dims;
gchar *tensor_name = NULL;
gint width = 0, height = 0;
const gchar *gst_format = NULL;
guint num_means, num_stddevs;
if (!_get_input_params (self, &data_type, &width, &height, &gst_format,
&priv->channels, &priv->planar)) {
GST_ERROR_OBJECT (self, "Failed to get parameters");
goto error;
}
if (!convert_tensor_info (tflite_tensor, &tname, &data_type,
&dims_count, &dims)) {
GST_ERROR_OBJECT (self, "Rejecting input_tensor[%d]:%s with no dims",
i, tname);
goto error;
}
tensor_name = modelinfo_find_tensor_name (modelinfo,
MODELINFO_DIRECTION_INPUT, i, tname, data_type, dims_count, dims);
if (tensor_name == NULL) {
gchar *dims_str = build_dims_str (dims_count, dims);
GST_DEBUG_OBJECT (self,
"Model info file doesn't contain info for input_tensor[%u]:%s matching the"
" type %s and dims %s", i, tname,
gst_tensor_data_type_get_name (data_type), dims_str);
g_free (dims);
g_free (dims_str);
} else {
num_means = modelinfo_get_normalization_means (modelinfo,
tensor_name, priv->channels, &priv->means);
if (num_means != priv->channels) {
priv->means = g_renew (gdouble, priv->means, priv->channels);
for (guint j = 1; j < priv->channels; j++)
priv->means[j] = priv->means[0];
}
num_stddevs = modelinfo_get_normalization_stddevs (modelinfo,
tensor_name, priv->channels, &priv->stddevs);
if (num_stddevs != priv->channels) {
priv->stddevs = g_renew (gdouble, priv->stddevs, priv->channels);
for (guint j = 1; j < priv->channels; j++)
priv->stddevs[j] = priv->stddevs[0];
}
}
gst_clear_caps (&priv->model_caps);
priv->model_caps = gst_caps_new_empty_simple ("video/x-raw");
if (width && height)
gst_caps_set_simple (priv->model_caps, "width", G_TYPE_INT, width,
"height", G_TYPE_INT, height, NULL);
if (data_type == GST_TENSOR_DATA_TYPE_UINT8 && gst_format &&
priv->means == NULL && priv->stddevs == NULL)
gst_caps_set_simple (priv->model_caps, "format", G_TYPE_STRING,
gst_format, NULL);
g_free (tensor_name);
}
if (TfLiteInterpreterAllocateTensors (priv->interpreter) != kTfLiteOk) {
GST_ERROR_OBJECT (self, "Failed to allocate tensors");
goto error;
}
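  /* Build a GstTensor template for each model output; the templates are
   * copied and filled with fresh data for every frame in
   * gst_tflite_inference_process(). */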
o_size = TfLiteInterpreterGetOutputTensorCount (priv->interpreter);
for (guint i = 0; i < o_size; i++) {
const TfLiteTensor *tflite_tensor =
TfLiteInterpreterGetOutputTensor (priv->interpreter, i);
const gchar *tname;
GstTensorDataType data_type;
gsize dims_count;
gsize *dims;
gchar *tensor_name = NULL;
if (!convert_tensor_info (tflite_tensor, &tname, &data_type,
&dims_count, &dims)) {
GST_WARNING_OBJECT (self, "Skipping output_tensor[%d]:%s with no dims",
i, tname);
continue;
}
tensor_name = modelinfo_find_tensor_name (modelinfo,
MODELINFO_DIRECTION_OUTPUT, i, tname, data_type, dims_count, dims);
gchar *dims_str = build_dims_str (dims_count, dims);
if (tensor_name == NULL) {
GST_ERROR_OBJECT (self,
"Model info file doesn't contain info for output_tensor[%u]:%s matching the"
" type %s and dims %s", i, tname,
gst_tensor_data_type_get_name (data_type), dims_str);
g_free (dims);
g_free (dims_str);
g_ptr_array_set_size (priv->tensor_templates, 0);
goto error;
}
GstTensor *t = gst_tensor_alloc (dims_count);
gchar *id = modelinfo_get_id (modelinfo, tensor_name);
GST_DEBUG_OBJECT (self, "Mapping output_tensor[%d]:%s of type %s and"
" dims %s to id %s", i, tname,
gst_tensor_data_type_get_name (data_type), dims_str, id);
g_free (id);
g_free (dims_str);
t->id = modelinfo_get_quark_id (modelinfo, tensor_name);
t->layout = GST_TENSOR_LAYOUT_CONTIGUOUS;
t->data_type = data_type;
t->dims_order = GST_TENSOR_DIM_ORDER_ROW_MAJOR;
memcpy (t->dims, dims, sizeof (gsize) * t->num_dims);
g_free (dims);
g_ptr_array_add (priv->tensor_templates, t);
g_free (tensor_name);
}
TfLiteTensor *itensor = TfLiteInterpreterGetInputTensor (priv->interpreter,
0);
if (TfLiteTensorType (itensor) == kTfLiteFloat32) {
GST_DEBUG_OBJECT (self, "Floating point Tensorflow Lite Model");
}
ret = TRUE;
done:
if (modelinfo)
modelinfo_free (modelinfo);
GST_OBJECT_UNLOCK (self);
return ret;
error:
GST_ERROR_OBJECT (self,
"Unable to create TFLITE session. Inference is disabled.");
GST_BASE_TRANSFORM_GET_CLASS (self)->stop (trans);
goto done;
}
static gboolean
gst_tflite_inference_stop (GstBaseTransform * trans)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
if (priv->interpreter)
TfLiteInterpreterDelete (priv->interpreter);
priv->interpreter = NULL;
if (priv->interpreter_options)
TfLiteInterpreterOptionsDelete (priv->interpreter_options);
priv->interpreter_options = NULL;
if (priv->model)
TfLiteModelDelete (priv->model);
priv->model = NULL;
gst_clear_caps (&priv->model_caps);
g_ptr_array_set_size (priv->tensor_templates, 0);
return TRUE;
}
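/* Constrain negotiated caps to what the model accepts: the input tensor's
 * width and height and, for uint8 models without normalization, the exact
 * video format. */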
static GstCaps *
gst_tflite_inference_transform_caps (GstBaseTransform * trans,
GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
GstCaps *other_caps;
if (priv->model_caps == NULL) {
other_caps = gst_caps_ref (caps);
goto done;
}
GST_DEBUG_OBJECT (self, "Applying caps restrictions: %" GST_PTR_FORMAT,
priv->model_caps);
other_caps = gst_caps_intersect_full (caps, priv->model_caps,
GST_CAPS_INTERSECT_FIRST);
done:
if (filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps,
GST_CAPS_INTERSECT_FIRST);
gst_caps_replace (&other_caps, tmp);
gst_caps_unref (tmp);
}
return other_caps;
}
static gboolean
gst_tflite_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
if (!gst_video_info_from_caps (&priv->video_info, incaps)) {
GST_ERROR_OBJECT (self, "Failed to parse caps");
return FALSE;
}
return TRUE;
}
static GstFlowReturn
gst_tflite_inference_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
{
if (!gst_base_transform_is_passthrough (trans)
&& !gst_tflite_inference_process (trans, buf)) {
GST_ELEMENT_ERROR (trans, STREAM, FAILED,
(NULL), ("TFLITE inference failed"));
return GST_FLOW_ERROR;
}
return GST_FLOW_OK;
}
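/* Copy an interleaved video frame into the model's input tensor, dropping
 * any alpha channel and applying per-channel normalization
 * ((sample + mean) / stddev). Depending on priv->planar the destination is
 * written interleaved (HWC) or planar (CHW). */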
#define _convert_image_remove_alpha(Type, dst, srcPtr, \
srcSamplesPerPixel, stride, means, stddevs) \
G_STMT_START { \
size_t destIndex = 0; \
Type tmp; \
\
if (!priv->planar) { \
for (int32_t j = 0; j < dstHeight; ++j) { \
for (int32_t i = 0; i < dstWidth; ++i) { \
for (int32_t k = 0; k < dstChannels; ++k) { \
tmp = *srcPtr[k]; \
tmp += means[k]; \
dst[destIndex++] = (Type)(tmp / stddevs[k]); \
srcPtr[k] += srcSamplesPerPixel; \
} \
} \
/* correct for stride */ \
for (uint32_t k = 0; k < 3; ++k) \
srcPtr[k] += stride - srcSamplesPerPixel * dstWidth; \
} \
} else { \
size_t frameSize = dstWidth * dstHeight; \
Type *destPtr[3] = { dst, dst + frameSize, dst + 2 * frameSize }; \
for (int32_t j = 0; j < dstHeight; ++j) { \
for (int32_t i = 0; i < dstWidth; ++i) { \
for (int32_t k = 0; k < dstChannels; ++k) { \
tmp = *srcPtr[k]; \
tmp += means[k]; \
destPtr[k][destIndex] = (Type)(tmp / stddevs[k]); \
srcPtr[k] += srcSamplesPerPixel; \
} \
destIndex++; \
} \
/* correct for stride */ \
for (uint32_t k = 0; k < 3; ++k) \
srcPtr[k] += stride - srcSamplesPerPixel * dstWidth; \
} \
} \
} \
G_STMT_END;
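/* uint8 tensors: default to identity normalization (mean 0, stddev 1). */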
static void
gst_tflite_inference_convert_image_remove_alpha_u8 (GstTFliteInference * self,
guint8 * dst, gint dstWidth, gint dstHeight, gint dstChannels,
guint8 ** srcPtr, guint8 srcSamplesPerPixel,
guint32 stride, const gdouble * means, const gdouble * stddevs)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
static const gdouble zeros[] = { 0, 0, 0, 0 };
static const gdouble ones[] = { 1.0, 1.0, 1.0, 1.0 };
if (means == NULL)
means = zeros;
if (stddevs == NULL)
stddevs = ones;
_convert_image_remove_alpha (guint8, dst, srcPtr, srcSamplesPerPixel,
stride, means, stddevs);
}
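/* float32 tensors: default to scaling samples into [0, 1] range
 * (mean 0, stddev 255). */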
static void
gst_tflite_inference_convert_image_remove_alpha_f32 (GstTFliteInference * self,
gfloat * dst, gint dstWidth, gint dstHeight, gint dstChannels,
guint8 ** srcPtr, guint8 srcSamplesPerPixel,
guint32 stride, const gdouble * means, const gdouble * stddevs)
{
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
static const gdouble zeros[] = { 0, 0, 0, 0 };
static const gdouble two_five_fives[] = { 255.0, 255.0, 255.0, 255.0 };
if (means == NULL)
means = zeros;
if (stddevs == NULL)
stddevs = two_five_fives;
_convert_image_remove_alpha (gfloat, dst, srcPtr, srcSamplesPerPixel,
stride, means, stddevs);
}
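/* Map the input buffer, point srcPtr[] at the R, G and B components of the
 * negotiated format, fill the model's input tensor, run inference and
 * attach the output tensors to the buffer as a GstTensorMeta. */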
static gboolean
gst_tflite_inference_process (GstBaseTransform * trans, GstBuffer * buf)
{
GstTFliteInference *self = GST_TFLITE_INFERENCE (trans);
GstTFliteInferencePrivate *priv =
gst_tflite_inference_get_instance_private (self);
GstMapInfo info;
guint8 *srcPtr[3];
gsize srcSamplesPerPixel = 3;
GstTensorDataType datatype;
if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
srcPtr[0] = info.data;
srcPtr[1] = info.data + 1;
srcPtr[2] = info.data + 2;
switch (priv->video_info.finfo->format) {
case GST_VIDEO_FORMAT_RGBA:
srcSamplesPerPixel = 4;
break;
case GST_VIDEO_FORMAT_BGRA:
srcSamplesPerPixel = 4;
srcPtr[0] = info.data + 2;
srcPtr[1] = info.data + 1;
srcPtr[2] = info.data + 0;
break;
case GST_VIDEO_FORMAT_ARGB:
srcSamplesPerPixel = 4;
srcPtr[0] = info.data + 1;
srcPtr[1] = info.data + 2;
srcPtr[2] = info.data + 3;
break;
case GST_VIDEO_FORMAT_ABGR:
srcSamplesPerPixel = 4;
srcPtr[0] = info.data + 3;
srcPtr[1] = info.data + 2;
srcPtr[2] = info.data + 1;
break;
case GST_VIDEO_FORMAT_BGR:
srcPtr[0] = info.data + 2;
srcPtr[1] = info.data + 1;
srcPtr[2] = info.data + 0;
break;
default:
break;
}
TfLiteTensor *tensor = TfLiteInterpreterGetInputTensor (priv->interpreter,
0);
guint width = GST_VIDEO_INFO_WIDTH (&priv->video_info);
guint height = GST_VIDEO_INFO_HEIGHT (&priv->video_info);
guint32 stride = priv->video_info.stride[0];
guint channels;
if (GST_VIDEO_INFO_IS_GRAY (&priv->video_info)) {
channels = 1;
} else if (GST_VIDEO_INFO_IS_RGB (&priv->video_info)) {
channels = 3;
} else {
g_assert_not_reached ();
}
datatype = gst_tflite_convert_data_type (TfLiteTensorType (tensor));
switch (datatype) {
case GST_TENSOR_DATA_TYPE_UINT8:{
uint8_t *dest = (uint8_t *) TfLiteTensorData (tensor);
          if (dest == NULL)
            goto error;
gst_tflite_inference_convert_image_remove_alpha_u8 (self,
dest, width, height, channels, srcPtr,
srcSamplesPerPixel, stride, priv->means, priv->stddevs);
break;
}
case GST_TENSOR_DATA_TYPE_FLOAT32:{
float *dest = (float *) TfLiteTensorData (tensor);
          if (dest == NULL)
            goto error;
gst_tflite_inference_convert_image_remove_alpha_f32 (self, dest,
width, height, channels, srcPtr,
srcSamplesPerPixel, stride, priv->means, priv->stddevs);
break;
}
      default:{
          GST_ERROR_OBJECT (self, "Data type not handled");
          goto error;
        }
}
/* Run inference */
if (TfLiteInterpreterInvoke (priv->interpreter) != kTfLiteOk) {
GST_ERROR_OBJECT (self, "Failed to invoke tflite!");
      goto error;
}
gsize num_tensors =
TfLiteInterpreterGetOutputTensorCount (priv->interpreter);
g_assert (num_tensors == priv->tensor_templates->len);
GstTensor **tensors =
(GstTensor **) g_malloc0_n (num_tensors, sizeof (gpointer));
for (size_t i = 0; i < num_tensors; i++) {
const TfLiteTensor *output_tensor =
TfLiteInterpreterGetOutputTensor (priv->interpreter, i);
tensors[i] = gst_tensor_alloc (TfLiteTensorNumDims (output_tensor));
memcpy (tensors[i], g_ptr_array_index (priv->tensor_templates, i),
sizeof (GstTensor));
tensors[i]->num_dims = TfLiteTensorNumDims (output_tensor);
for (gsize j = 0; j < tensors[i]->num_dims; j++)
        tensors[i]->dims[j] = TfLiteTensorDim (output_tensor, j);
tensors[i]->data =
gst_buffer_new_allocate (NULL, TfLiteTensorByteSize (output_tensor),
NULL);
gst_buffer_fill (tensors[i]->data, 0, TfLiteTensorData (output_tensor),
TfLiteTensorByteSize (output_tensor));
}
    GstTensorMeta *tmeta = gst_buffer_add_tensor_meta (buf);
    if (!tmeta) {
      for (gsize i = 0; i < num_tensors; i++)
        gst_tensor_free (tensors[i]);
      g_free (tensors);
      goto error;
    }
    gst_tensor_meta_set (tmeta, num_tensors, tensors);
GST_TRACE_OBJECT (trans, "Num tensors: %zu", tmeta->num_tensors);
    gst_buffer_unmap (buf, &info);
  }
  return TRUE;

error:
  gst_buffer_unmap (buf, &info);
  return FALSE;
}