classifiertensordecoder: Use utility functions to get tensors

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/9419>
This commit is contained in:
Olivier Crête 2025-07-17 17:14:06 -04:00 committed by GStreamer Marge Bot
parent d1483d6c55
commit d1b00839c3

View File

@ -408,66 +408,43 @@ gst_classifier_tensor_decoder_decode (GstClassifierTensorDecoder * self,
gint max_idx = -1;
const GstTensor *tensor;
GstAnalyticsClsMtd cls_mtd;
const gsize DIMS[] = { 1, G_MAXSIZE };
tensor = gst_tensor_meta_get_by_id (tmeta,
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION));
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
if (tensor == NULL)
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
if (tensor == NULL)
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
if (tensor == NULL)
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
if (tensor->dims_order != GST_TENSOR_DIM_ORDER_ROW_MAJOR) {
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
("Only row-major tensor are supported"),
("this element only support tensor with dims_order set to "
"GST_TENSOR_DIM_ORDER_ROW_MAJOR"));
return GST_FLOW_ERROR;
}
if (tensor->num_dims != 1 && tensor->num_dims != 2) {
if (tensor == NULL) {
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
("Only tenson of 1 dimension is supported."),
("tensor dimension must be 1xm or m."));
(NULL), ("Could not find classification tensor"));
return GST_FLOW_ERROR;
}
if (tensor->data_type != GST_TENSOR_DATA_TYPE_FLOAT32 &&
tensor->data_type != GST_TENSOR_DATA_TYPE_UINT8) {
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
("Only data-type UINT8 and FLOAT32 support is implemented"),
("Please implement."));
len = tensor->dims[tensor->num_dims - 1];
return GST_FLOW_ERROR;
if (len != self->class_quark->len) {
GST_WARNING_OBJECT (self, "Labels file has size %zu, but the tensor has"
" %u entries, it is probably not the right labels file",
len, self->class_quark->len);
len = MIN (len, self->class_quark->len);
}
if (tensor->num_dims == 1) {
if (tensor->dims[0] == 0) {
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"),
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"));
return GST_FLOW_ERROR;
}
len = tensor->dims[0];
} else {
if (tensor->dims[0] != 1) {
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
("Batch not implemented"),
("Batch not implemented, please implement"));
return GST_FLOW_ERROR;
}
if (tensor->dims[1] == 0) {
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"),
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"));
return GST_FLOW_ERROR;
}
len = tensor->dims[1];
}
g_return_val_if_fail (len == self->class_quark->len, GST_FLOW_ERROR);
if (!gst_buffer_map (tensor->data, &map_info, GST_MAP_READ)) {
GST_ERROR_OBJECT (self, "Failed to map tensor data");
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
("Failed to map tensor data"));
return GST_FLOW_ERROR;
}
GST_TRACE_OBJECT (self, "Tensor shape dims %zu", tensor->num_dims);