ssdobjectdetector: Validate tensor type and dimensions
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/9419>
This commit is contained in:
parent
561e2b28af
commit
fe0d7f249d
@ -306,48 +306,80 @@ gst_ssd_object_detector_get_property (GObject * object, guint prop_id,
|
||||
}
|
||||
}
|
||||
|
||||
static GstTensorMeta *
|
||||
gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector,
|
||||
GstBuffer * buf)
|
||||
static gboolean
|
||||
gst_ssd_object_detector_get_tensors (GstSsdObjectDetector * object_detector,
|
||||
GstBuffer * buf, const GstTensor ** classes_tensor,
|
||||
const GstTensor ** numdetect_tensor, const GstTensor ** scores_tensor,
|
||||
const GstTensor ** boxes_tensor)
|
||||
{
|
||||
GstMeta *meta = NULL;
|
||||
gpointer iter_state = NULL;
|
||||
static const gsize BOXES_DIMS[] = { 1, G_MAXSIZE, 4 };
|
||||
static const gsize NUM_DETECT_DIMS[] = { 1 };
|
||||
static const gsize SCORES_CLASSES_DIMS[] = { 1, G_MAXSIZE };
|
||||
|
||||
if (!gst_buffer_get_meta (buf, GST_TENSOR_META_API_TYPE)) {
|
||||
GST_DEBUG_OBJECT (object_detector,
|
||||
"missing tensor meta from buffer %" GST_PTR_FORMAT, buf);
|
||||
return NULL;
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
// find object detector meta
|
||||
|
||||
while ((meta = gst_buffer_iterate_meta_filtered (buf, &iter_state,
|
||||
GST_TENSOR_META_API_TYPE))) {
|
||||
GstTensorMeta *tensor_meta = (GstTensorMeta *) meta;
|
||||
/* SSD model must have either 3 or 4 output tensor nodes: 4 if there is a label node,
|
||||
* and only 3 if there is no label */
|
||||
if (tensor_meta->num_tensors != 3 && tensor_meta->num_tensors != 4)
|
||||
GstTensorMeta *tmeta = (GstTensorMeta *) meta;
|
||||
|
||||
*boxes_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES),
|
||||
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 3,
|
||||
BOXES_DIMS);
|
||||
if (*boxes_tensor == NULL)
|
||||
*boxes_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES),
|
||||
GST_TENSOR_DATA_TYPE_UINT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 3,
|
||||
BOXES_DIMS);
|
||||
if (*boxes_tensor == NULL)
|
||||
continue;
|
||||
|
||||
gint boxesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES));
|
||||
gint scoresIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
|
||||
gint numDetectionsIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
|
||||
gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||
|
||||
if (boxesIndex == -1 || scoresIndex == -1 || numDetectionsIndex == -1)
|
||||
*scores_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES),
|
||||
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2,
|
||||
SCORES_CLASSES_DIMS);
|
||||
if (*scores_tensor == NULL)
|
||||
*scores_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES),
|
||||
GST_TENSOR_DATA_TYPE_UINT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2,
|
||||
SCORES_CLASSES_DIMS);
|
||||
if (*scores_tensor == NULL)
|
||||
continue;
|
||||
|
||||
if (tensor_meta->num_tensors == 4 && clasesIndex == -1)
|
||||
*numdetect_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS),
|
||||
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1,
|
||||
NUM_DETECT_DIMS);
|
||||
if (*numdetect_tensor == NULL)
|
||||
*numdetect_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS),
|
||||
GST_TENSOR_DATA_TYPE_UINT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1,
|
||||
NUM_DETECT_DIMS);
|
||||
if (*numdetect_tensor == NULL)
|
||||
continue;
|
||||
|
||||
return tensor_meta;
|
||||
*classes_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES),
|
||||
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2,
|
||||
SCORES_CLASSES_DIMS);
|
||||
if (*classes_tensor == NULL)
|
||||
*classes_tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES),
|
||||
GST_TENSOR_DATA_TYPE_UINT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2,
|
||||
SCORES_CLASSES_DIMS);
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
@ -380,7 +412,7 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
|
||||
|
||||
#define DEFINE_GET_FUNC(TYPE, MAX) \
|
||||
static gboolean \
|
||||
get_ ## TYPE ## _at_index (GstTensor *tensor, GstMapInfo *map, \
|
||||
get_ ## TYPE ## _at_index (const GstTensor *tensor, GstMapInfo *map, \
|
||||
guint index, TYPE * out) \
|
||||
{ \
|
||||
switch (tensor->data_type) { \
|
||||
@ -405,18 +437,16 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
|
||||
return TRUE; \
|
||||
}
|
||||
|
||||
DEFINE_GET_FUNC (guint32, UINT32_MAX)
|
||||
DEFINE_GET_FUNC (float, FLOAT_MAX)
|
||||
DEFINE_GET_FUNC (guint32, UINT32_MAX);
|
||||
DEFINE_GET_FUNC (float, FLOAT_MAX);
|
||||
#undef DEFINE_GET_FUNC
|
||||
static void
|
||||
extract_bounding_boxes (GstSsdObjectDetector * self, gsize w, gsize h,
|
||||
GstAnalyticsRelationMeta * rmeta, GstTensorMeta * tmeta)
|
||||
{
|
||||
gint classes_index;
|
||||
gint boxes_index;
|
||||
gint scores_index;
|
||||
gint numdetect_index;
|
||||
|
||||
static void
|
||||
extract_bounding_boxes (GstSsdObjectDetector * self, gsize w, gsize h,
|
||||
GstAnalyticsRelationMeta * rmeta, const GstTensor * classes_tensor,
|
||||
const GstTensor * numdetect_tensor, const GstTensor * scores_tensor,
|
||||
const GstTensor * boxes_tensor)
|
||||
{
|
||||
GstMapInfo boxes_map = GST_MAP_INFO_INIT;
|
||||
GstMapInfo numdetect_map = GST_MAP_INFO_INIT;
|
||||
GstMapInfo scores_map = GST_MAP_INFO_INIT;
|
||||
@ -424,57 +454,49 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX)
|
||||
|
||||
guint num_detections = 0;
|
||||
|
||||
classes_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||
numdetect_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
|
||||
scores_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
|
||||
boxes_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES));
|
||||
|
||||
if (numdetect_index == -1 || scores_index == -1 || numdetect_index == -1) {
|
||||
if (numdetect_tensor == NULL || scores_tensor == NULL || boxes_tensor == NULL) {
|
||||
GST_WARNING ("Missing tensor data expected for SSD model");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!gst_buffer_map (tmeta->tensors[numdetect_index]->data, &numdetect_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
numdetect_index);
|
||||
if (!gst_buffer_map (numdetect_tensor->data, &numdetect_map, GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map numdetect tensor memory");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (!gst_buffer_map (tmeta->tensors[boxes_index]->data, &boxes_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
boxes_index);
|
||||
if (!gst_buffer_map (boxes_tensor->data, &boxes_map, GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map boxes tensor memory");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (!gst_buffer_map (tmeta->tensors[scores_index]->data, &scores_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
scores_index);
|
||||
if (!gst_buffer_map (scores_tensor->data, &scores_map, GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map scores tensor memory");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (classes_index != -1 &&
|
||||
!gst_buffer_map (tmeta->tensors[classes_index]->data, &classes_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_DEBUG_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
classes_index);
|
||||
if (classes_tensor &&
|
||||
!gst_buffer_map (classes_tensor->data, &classes_map, GST_MAP_READ)) {
|
||||
GST_DEBUG_OBJECT (self, "Failed to map classes tensor memory");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
|
||||
if (!get_guint32_at_index (tmeta->tensors[numdetect_index], &numdetect_map,
|
||||
if (!get_guint32_at_index (numdetect_tensor, &numdetect_map,
|
||||
0, &num_detections)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to get the number of detections");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
|
||||
GST_LOG_OBJECT (self, "Model claims %d detections", num_detections);
|
||||
GST_LOG_OBJECT (self, "Model claims %u detections", num_detections);
|
||||
num_detections = MIN (num_detections, scores_tensor->dims[1]);
|
||||
num_detections = MIN (num_detections, boxes_tensor->dims[1]);
|
||||
if (classes_tensor)
|
||||
num_detections = MIN (num_detections, classes_tensor->dims[1]);
|
||||
GST_LOG_OBJECT (self, "Model really has %u detections"
|
||||
" (%zu scores, %zu boxes, %zu classes)", num_detections,
|
||||
scores_tensor->dims[1], boxes_tensor->dims[1],
|
||||
classes_tensor ? classes_tensor->dims[1] : 0);
|
||||
|
||||
for (int i = 0; i < num_detections; i++) {
|
||||
float score;
|
||||
@ -484,25 +506,20 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX)
|
||||
GQuark label = 0;
|
||||
GstAnalyticsODMtd odmtd;
|
||||
|
||||
if (!get_float_at_index (tmeta->tensors[scores_index], &scores_map,
|
||||
i, &score))
|
||||
if (!get_float_at_index (scores_tensor, &scores_map, i, &score))
|
||||
continue;
|
||||
|
||||
GST_LOG_OBJECT (self, "Detection %u score is %f", i, score);
|
||||
if (score < self->score_threshold)
|
||||
continue;
|
||||
|
||||
if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map,
|
||||
i * 4, &y))
|
||||
if (!get_float_at_index (boxes_tensor, &boxes_map, i * 4, &y))
|
||||
continue;
|
||||
if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map,
|
||||
i * 4 + 1, &x))
|
||||
if (!get_float_at_index (boxes_tensor, &boxes_map, i * 4 + 1, &x))
|
||||
continue;
|
||||
if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map,
|
||||
i * 4 + 2, &bheight))
|
||||
if (!get_float_at_index (boxes_tensor, &boxes_map, i * 4 + 2, &bheight))
|
||||
continue;
|
||||
if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map,
|
||||
i * 4 + 3, &bwidth))
|
||||
if (!get_float_at_index (boxes_tensor, &boxes_map, i * 4 + 3, &bwidth))
|
||||
continue;
|
||||
|
||||
if (CLAMP (bwidth, 0, 1) * CLAMP (bheight, 0, 1) > self->size_threshold) {
|
||||
@ -513,8 +530,7 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX)
|
||||
}
|
||||
|
||||
if (self->labels && classes_map.memory &&
|
||||
get_guint32_at_index (tmeta->tensors[classes_index], &classes_map,
|
||||
i, &bclass)) {
|
||||
get_guint32_at_index (classes_tensor, &classes_map, i, &bclass)) {
|
||||
if (bclass < self->labels->len)
|
||||
label = g_array_index (self->labels, GQuark, bclass);
|
||||
}
|
||||
@ -536,13 +552,13 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX)
|
||||
cleanup:
|
||||
|
||||
if (numdetect_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensors[numdetect_index]->data, &numdetect_map);
|
||||
gst_buffer_unmap (numdetect_tensor->data, &numdetect_map);
|
||||
if (classes_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensors[classes_index]->data, &classes_map);
|
||||
gst_buffer_unmap (classes_tensor->data, &classes_map);
|
||||
if (scores_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensors[scores_index]->data, &scores_map);
|
||||
gst_buffer_unmap (scores_tensor->data, &scores_map);
|
||||
if (boxes_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensors[boxes_index]->data, &boxes_map);
|
||||
gst_buffer_unmap (boxes_tensor->data, &boxes_map);
|
||||
}
|
||||
|
||||
|
||||
@ -550,12 +566,15 @@ static gboolean
|
||||
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||
{
|
||||
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
|
||||
GstTensorMeta *tmeta;
|
||||
GstAnalyticsRelationMeta *rmeta;
|
||||
const GstTensor *classes_tensor = NULL;
|
||||
const GstTensor *numdetect_tensor = NULL;
|
||||
const GstTensor *scores_tensor = NULL;
|
||||
const GstTensor *boxes_tensor = NULL;
|
||||
|
||||
// get all tensor metas
|
||||
tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
|
||||
if (!tmeta) {
|
||||
if (!gst_ssd_object_detector_get_tensors (self, buf,
|
||||
&classes_tensor, &numdetect_tensor, &scores_tensor, &boxes_tensor)) {
|
||||
GST_WARNING_OBJECT (trans, "missing tensor meta");
|
||||
return TRUE;
|
||||
} else {
|
||||
@ -564,7 +583,8 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||
}
|
||||
|
||||
extract_bounding_boxes (self, self->video_info.width,
|
||||
self->video_info.height, rmeta, tmeta);
|
||||
self->video_info.height, rmeta, classes_tensor, numdetect_tensor,
|
||||
scores_tensor, boxes_tensor);
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user