onnx: Allow generic well-known names for tensors

This allows us to use the upstream version of the ssd_mobilenet model[1], and
starts setting us up to allow some tensor names by convention if we want to add
more decoders.

[1] https://github.com/onnx/models/tree/main/validated/vision/object_detection_segmentation/ssd-mobilenetv1

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/8117>
This commit is contained in:
Arun Raghavan 2024-12-09 13:39:16 -05:00 committed by GStreamer Marge Bot
parent 96e660e0d9
commit 5ffa6902c3

View File

@ -26,6 +26,12 @@
#define GST_CAT_DEFAULT onnx_inference_debug
/* FIXME: share this with tensordecoders, somehow? */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
namespace GstOnnxNamespace
{
template < typename T >
@ -286,15 +292,38 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return false;
}
for (auto & name:outputNamesRaw) {
Ort::AllocatedStringPtr res =
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
if (res)
{
GQuark quark = g_quark_from_string (res.get ());
outputIds.push_back (quark);
} else {
Ort::AllocatedStringPtr res =
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
if (res)
{
GQuark quark = g_quark_from_string (res.get ());
outputIds.push_back (quark);
} else if (g_str_has_prefix (name, "detection_scores")) {
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES);
GST_INFO_OBJECT(debug_parent,
"No custom metadata for key '%s', assuming %s",
name, GST_MODEL_OBJECT_DETECTOR_SCORES);
outputIds.push_back (quark);
} else if (g_str_has_prefix(name, "detection_boxes")) {
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES);
GST_INFO_OBJECT(debug_parent,
"No custom metadata for key '%s', assuming %s",
name, GST_MODEL_OBJECT_DETECTOR_BOXES);
outputIds.push_back (quark);
} else if (g_str_has_prefix(name, "detection_classes")) {
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES);
GST_INFO_OBJECT(debug_parent,
"No custom metadata for key '%s', assuming %s",
name, GST_MODEL_OBJECT_DETECTOR_CLASSES);
outputIds.push_back (quark);
} else if (g_str_has_prefix(name, "num_detections")) {
GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS);
GST_INFO_OBJECT(debug_parent,
"No custom metadata for key '%s', assuming %s",
name, GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS);
outputIds.push_back (quark);
} else {
GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
return false;
}
}