classifiertensordecoder: Use utility functions to get tensors
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/9419>
This commit is contained in:
parent
d1483d6c55
commit
d1b00839c3
@ -408,66 +408,43 @@ gst_classifier_tensor_decoder_decode (GstClassifierTensorDecoder * self,
|
||||
gint max_idx = -1;
|
||||
const GstTensor *tensor;
|
||||
GstAnalyticsClsMtd cls_mtd;
|
||||
const gsize DIMS[] = { 1, G_MAXSIZE };
|
||||
|
||||
tensor = gst_tensor_meta_get_by_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION));
|
||||
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
|
||||
if (tensor == NULL)
|
||||
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||
GST_TENSOR_DATA_TYPE_FLOAT32, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
|
||||
if (tensor == NULL)
|
||||
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 1, NULL);
|
||||
if (tensor == NULL)
|
||||
tensor = gst_tensor_meta_get_typed_tensor (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION),
|
||||
GST_TENSOR_DATA_TYPE_UINT8, GST_TENSOR_DIM_ORDER_ROW_MAJOR, 2, DIMS);
|
||||
|
||||
if (tensor->dims_order != GST_TENSOR_DIM_ORDER_ROW_MAJOR) {
|
||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
|
||||
("Only row-major tensor are supported"),
|
||||
("this element only support tensor with dims_order set to "
|
||||
"GST_TENSOR_DIM_ORDER_ROW_MAJOR"));
|
||||
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
|
||||
if (tensor->num_dims != 1 && tensor->num_dims != 2) {
|
||||
if (tensor == NULL) {
|
||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
||||
("Only tenson of 1 dimension is supported."),
|
||||
("tensor dimension must be 1xm or m."));
|
||||
|
||||
(NULL), ("Could not find classification tensor"));
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
|
||||
if (tensor->data_type != GST_TENSOR_DATA_TYPE_FLOAT32 &&
|
||||
tensor->data_type != GST_TENSOR_DATA_TYPE_UINT8) {
|
||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
|
||||
("Only data-type UINT8 and FLOAT32 support is implemented"),
|
||||
("Please implement."));
|
||||
len = tensor->dims[tensor->num_dims - 1];
|
||||
|
||||
return GST_FLOW_ERROR;
|
||||
if (len != self->class_quark->len) {
|
||||
GST_WARNING_OBJECT (self, "Labels file has size %zu, but the tensor has"
|
||||
" %u entries, it is probably not the right labels file",
|
||||
len, self->class_quark->len);
|
||||
len = MIN (len, self->class_quark->len);
|
||||
}
|
||||
|
||||
if (tensor->num_dims == 1) {
|
||||
if (tensor->dims[0] == 0) {
|
||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"),
|
||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"));
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
len = tensor->dims[0];
|
||||
} else {
|
||||
if (tensor->dims[0] != 1) {
|
||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, NOT_IMPLEMENTED,
|
||||
("Batch not implemented"),
|
||||
("Batch not implemented, please implement"));
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
|
||||
if (tensor->dims[1] == 0) {
|
||||
GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
|
||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"),
|
||||
("A tensor without content (dims[0] ==0, num_dims=1) can't be used"));
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
|
||||
len = tensor->dims[1];
|
||||
}
|
||||
|
||||
g_return_val_if_fail (len == self->class_quark->len, GST_FLOW_ERROR);
|
||||
|
||||
if (!gst_buffer_map (tensor->data, &map_info, GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map tensor data");
|
||||
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
|
||||
("Failed to map tensor data"));
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
|
||||
GST_TRACE_OBJECT (self, "Tensor shape dims %zu", tensor->num_dims);
|
||||
|
Loading…
x
Reference in New Issue
Block a user