classifiertensordecoder: Use utility functions to get tensors
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/9419>
This commit is contained in:
parent
d1483d6c55
commit
d1b00839c3
@ -408,66 +408,43 @@ gst_classifier_tensor_decoder_decode (GstClassifierTensorDecoder * self,
|
|||||||
gint max_idx = -1;
|
gint max_idx = -1;
|
||||||
const GstTensor *tensor;
|
const GstTensor *tensor;
|
||||||
GstAnalyticsClsMtd cls_mtd;
|
GstAnalyticsClsMtd cls_mtd;
|
||||||
|
const gsize DIMS[] = { 1, G_MAXSIZE };
|
||||||
|
|
||||||
tensor = gst_tensor_meta_get_by_id (tmeta,
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||||
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION));
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||||
|
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
|
||||||
|
if (tensor == NULL)
|
||||||
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||||
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||||
|
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
|
||||||
|
if (tensor == NULL)
|
||||||
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||||
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||||
|
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
|
||||||
|
if (tensor == NULL)
|
||||||
|
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||||
|
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||||
|
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
|
||||||
|
|
||||||
if (tensor->dims_order != GST_TENSOR_DIM_ORDER_ROW_MAJOR) {
|
if (tensor == NULL) {
|
||||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
|
|
||||||
("Only row-major tensor are supported"),
|
|
||||||
("this element only support tensor with dims_order set to "
|
|
||||||
"GST_TENSOR_DIM_ORDER_ROW_MAJOR"));
|
|
||||||
|
|
||||||
return GST_FLOW_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tensor->num_dims != 1 && tensor->num_dims != 2) {
|
|
||||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
||||||
("Only tenson of 1 dimension is supported."),
|
(NULL), ("Could not find classification tensor"));
|
||||||
("tensor dimension must be 1xm or m."));
|
|
||||||
|
|
||||||
return GST_FLOW_ERROR;
|
return GST_FLOW_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tensor->data_type != GST_TENSOR_DATA_TYPE_FLOAT32 &&
|
len = tensor->dims[tensor->num_dims - 1];
|
||||||
tensor->data_type != GST_TENSOR_DATA_TYPE_UINT8) {
|
|
||||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
|
|
||||||
("Only data-type UINT8 and FLOAT32 support is implemented"),
|
|
||||||
("Please implement."));
|
|
||||||
|
|
||||||
return GST_FLOW_ERROR;
|
if (len != self->class_quark->len) {
|
||||||
|
GST_WARNING_OBJECT (self, "Labels file has size %zu, but the tensor has"
|
||||||
|
" %u entries, it is probably not the right labels file",
|
||||||
|
len, self->class_quark->len);
|
||||||
|
len = MIN (len, self->class_quark->len);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tensor->num_dims == 1) {
|
|
||||||
if (tensor->dims[0] == 0) {
|
|
||||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
|
||||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"),
|
|
||||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"));
|
|
||||||
return GST_FLOW_ERROR;
|
|
||||||
}
|
|
||||||
len = tensor->dims[0];
|
|
||||||
} else {
|
|
||||||
if (tensor->dims[0] != 1) {
|
|
||||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
|
|
||||||
("Batch not implemented"),
|
|
||||||
("Batch not implemented, please implement"));
|
|
||||||
return GST_FLOW_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tensor->dims[1] == 0) {
|
|
||||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
|
||||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"),
|
|
||||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"));
|
|
||||||
return GST_FLOW_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
len = tensor->dims[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
g_return_val_if_fail (len == self->class_quark->len, GST_FLOW_ERROR);
|
|
||||||
|
|
||||||
if (!gst_buffer_map (tensor->data, &map_info, GST_MAP_READ)) {
|
if (!gst_buffer_map (tensor->data, &map_info, GST_MAP_READ)) {
|
||||||
GST_ERROR_OBJECT (self, "Failed to map tensor data");
|
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
|
||||||
|
("Failed to map tensor data"));
|
||||||
|
return GST_FLOW_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
GST_TRACE_OBJECT (self, "Tensor shape dims %zu", tensor->num_dims);
|
GST_TRACE_OBJECT (self, "Tensor shape dims %zu", tensor->num_dims);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user