From 5ffa6902c382c4ed6fc7c1718c1be4443c188b87 Mon Sep 17 00:00:00 2001
From: Arun Raghavan <arun@asymptotic.io>
Date: Mon, 9 Dec 2024 13:39:16 -0500
Subject: [PATCH] onnx: Allow generic well-known names for tensors

This allows us to use the upstream version of the ssd_mobilenet model[1], and
starts setting us up to allow some tensor names by convention if we want to add
more decoders.

[1] https://github.com/onnx/models/tree/main/validated/vision/object_detection_segmentation/ssd-mobilenetv1

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/8117>
---
 .../ext/onnx/gstonnxclient.cpp                | 45 +++++++++++++++----
 1 file changed, 37 insertions(+), 8 deletions(-)

diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
index d632af3c6d..2b55eafd25 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
+++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
@@ -26,6 +26,12 @@
 
 #define GST_CAT_DEFAULT onnx_inference_debug
 
+/* FIXME: share this with tensordecoders, somehow? */
+#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
+#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
+#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
+#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
+
 namespace GstOnnxNamespace
 {
   template < typename T >
@@ -286,15 +292,38 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
         return false;
       }
       for (auto & name:outputNamesRaw) {
-          Ort::AllocatedStringPtr res =
-              metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
-          if (res)
-            {
-              GQuark quark = g_quark_from_string (res.get ());
-              outputIds.push_back (quark);
-            } else {
+        Ort::AllocatedStringPtr res =
+          metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
+        if (res)
+        {
+          GQuark quark = g_quark_from_string (res.get ());
+          outputIds.push_back (quark);
+        } else if (g_str_has_prefix (name, "detection_scores")) {
+          GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES);
+          GST_INFO_OBJECT(debug_parent,
+              "No custom metadata for key '%s', assuming %s",
+              name, GST_MODEL_OBJECT_DETECTOR_SCORES);
+          outputIds.push_back (quark);
+        } else if (g_str_has_prefix(name, "detection_boxes")) {
+          GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES);
+          GST_INFO_OBJECT(debug_parent,
+              "No custom metadata for key '%s', assuming %s",
+              name, GST_MODEL_OBJECT_DETECTOR_BOXES);
+          outputIds.push_back (quark);
+        } else if (g_str_has_prefix(name, "detection_classes")) {
+          GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES);
+          GST_INFO_OBJECT(debug_parent,
+              "No custom metadata for key '%s', assuming %s",
+              name, GST_MODEL_OBJECT_DETECTOR_CLASSES);
+          outputIds.push_back (quark);
+        } else if (g_str_has_prefix(name, "num_detections")) {
+          GQuark quark = g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS);
+          GST_INFO_OBJECT(debug_parent,
+              "No custom metadata for key '%s', assuming %s",
+              name, GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS);
+          outputIds.push_back (quark);
+        } else {
           GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
-
           return false;
         }
       }