fastsamtensordecoder: Set mask resolution based on model output

This commit is contained in:
Olivier Crête 2025-01-06 13:29:28 -06:00 committed by Elias Rosendahl
parent b7f964929c
commit 244dd01b22

View File

@ -132,7 +132,7 @@ static void gst_fastsam_tensor_decoder_set_property (GObject * object,
static void gst_fastsam_tensor_decoder_get_property (GObject * object, static void gst_fastsam_tensor_decoder_get_property (GObject * object,
guint prop_id, GValue * value, GParamSpec * pspec); guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_fastsam_tensor_decoder_finalize (GObject * object); static gboolean gst_fastsam_tensor_decoder_stop (GstBaseTransform * trans);
static GstFlowReturn gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform * static GstFlowReturn gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform *
trans, GstBuffer * buf); trans, GstBuffer * buf);
@ -160,9 +160,6 @@ gst_fastsam_tensor_decoder_class_init (GstFastSAMTensorDecoderClass * klass)
gobject_class->set_property = gst_fastsam_tensor_decoder_set_property; gobject_class->set_property = gst_fastsam_tensor_decoder_set_property;
gobject_class->get_property = gst_fastsam_tensor_decoder_get_property; gobject_class->get_property = gst_fastsam_tensor_decoder_get_property;
/* Set GObject vmethod finalize */
gobject_class->finalize = gst_fastsam_tensor_decoder_finalize;
/* Define GstFastSAMTensorDecoder properties using GObject properties /* Define GstFastSAMTensorDecoder properties using GObject properties
* interface.*/ * interface.*/
g_object_class_install_property (G_OBJECT_CLASS (klass), g_object_class_install_property (G_OBJECT_CLASS (klass),
@ -254,6 +251,10 @@ gst_fastsam_tensor_decoder_class_init (GstFastSAMTensorDecoderClass * klass)
basetransform_class->set_caps = basetransform_class->set_caps =
GST_DEBUG_FUNCPTR (gst_fastsam_tensor_decoder_set_caps); GST_DEBUG_FUNCPTR (gst_fastsam_tensor_decoder_set_caps);
/* Set GObject vmethod finalize */
basetransform_class->stop = gst_fastsam_tensor_decoder_stop;
/* Calculate the class id placeholder (also a quark) that will be set on all /* Calculate the class id placeholder (also a quark) that will be set on all
* OD analytics-meta. */ * OD analytics-meta. */
OOI_CLASS_ID = g_quark_from_static_string ("FastSAM-None"); OOI_CLASS_ID = g_quark_from_static_string ("FastSAM-None");
@ -277,32 +278,31 @@ gst_fastsam_tensor_decoder_init (GstFastSAMTensorDecoder * self)
self->max_detection = DEFAULT_MAX_DETECTION; self->max_detection = DEFAULT_MAX_DETECTION;
self->sel_candidates = NULL; self->sel_candidates = NULL;
self->selected = NULL; self->selected = NULL;
self->mask_w = 256; self->mask_w = 0;
self->mask_h = 256; self->mask_h = 0;
self->mask_length = self->mask_w * self->mask_h; self->mask_length = 0;
memset (&self->mask_roi, 0, sizeof (BBox)); memset (&self->mask_roi, 0, sizeof (BBox));
self->mask_pool = NULL; self->mask_pool = NULL;
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), FALSE); gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), FALSE);
} }
static void static gboolean
gst_fastsam_tensor_decoder_finalize (GObject * object) gst_fastsam_tensor_decoder_stop (GstBaseTransform * trans)
{ {
GstFastSAMTensorDecoder *self = GST_FASTSAM_TENSOR_DECODER (object); GstFastSAMTensorDecoder *self = GST_FASTSAM_TENSOR_DECODER (trans);
if (self->sel_candidates) { self->mask_w = 0;
g_ptr_array_unref (g_steal_pointer (&self->sel_candidates)); self->mask_h = 0;
} self->mask_length = 0;
if (self->selected) { g_clear_pointer (&self->sel_candidates, g_ptr_array_unref);
g_ptr_array_unref (g_steal_pointer (&self->selected)); g_clear_pointer (&self->selected, g_ptr_array_unref);
} if (self->mask_pool)
gst_buffer_pool_set_active (self->mask_pool, FALSE);
if (self->mask_pool) { g_clear_object (&self->mask_pool);
gst_object_unref (self->mask_pool);
}
G_OBJECT_CLASS (gst_fastsam_tensor_decoder_parent_class)->finalize (object); return TRUE;
} }
static void static void
@ -458,46 +458,6 @@ gst_fastsam_tensor_decoder_set_caps (GstBaseTransform * trans, GstCaps * incaps,
return FALSE; return FALSE;
} }
/* The masks need to be cropped to fit the SAR of the image. */
/* TODO: We're reconstructing the transformation that was done on the
* original image based on the assumption that the complete image without
* deformation would be analyzed. This assumption is not alway true and
* we should try to find a way to convey this transformation information
* and retrieve from here to know the transformation that need to be done
* on the mask.*/
if (self->video_info.width > self->video_info.height) {
self->bb2mask_gain = ((gfloat) self->mask_w) / self->video_info.width;
self->mask_roi.x = 0;
self->mask_roi.w = self->mask_w;
self->mask_roi.h = ((gfloat) self->bb2mask_gain) * self->video_info.height;
self->mask_roi.y = (self->mask_h - self->mask_roi.h) / 2;
} else {
self->bb2mask_gain = ((gfloat) self->mask_h) / self->video_info.height;
self->mask_roi.y = 0;
self->mask_roi.h = self->mask_h;
self->mask_roi.w = self->bb2mask_gain * self->video_info.width;
self->mask_roi.x = (self->mask_w - self->mask_roi.w) / 2;
}
if (self->mask_pool == NULL) {
GstVideoInfo minfo;
GstCaps *caps;
gst_video_info_init (&minfo);
gst_video_info_set_format (&minfo, GST_VIDEO_FORMAT_GRAY8, 256, 256);
caps = gst_video_info_to_caps (&minfo);
self->mask_pool = gst_video_buffer_pool_new ();
GstStructure *config = gst_buffer_pool_get_config (self->mask_pool);
gst_buffer_pool_config_set_params (config, caps, self->mask_length, 0, 0);
gst_buffer_pool_config_add_option (config,
GST_BUFFER_POOL_OPTION_VIDEO_META);
g_return_val_if_fail (gst_buffer_pool_set_config (self->mask_pool, config),
FALSE);
g_return_val_if_fail (gst_buffer_pool_set_active (self->mask_pool, TRUE),
FALSE);
gst_caps_unref (caps);
}
return TRUE; return TRUE;
} }
@ -522,10 +482,10 @@ gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform * trans,
&logits_tensor)) &logits_tensor))
return GST_FLOW_OK; return GST_FLOW_OK;
if (masks_tensor->num_dims < 3) { if (masks_tensor->num_dims != 3) {
GST_ELEMENT_ERROR (self, STREAM, DECODE, (NULL), GST_ELEMENT_ERROR (self, STREAM, DECODE, (NULL),
("Masks tensor must have at least 3 dimensions," ("Masks tensor must have 3 dimensions but has %zu",
"but only has %zu", masks_tensor->num_dims)); masks_tensor->num_dims));
return GST_FLOW_ERROR; return GST_FLOW_ERROR;
} }
@ -603,15 +563,10 @@ gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform * trans,
* retrieve an analytics-relation-meta if it exist or create one if it * retrieve an analytics-relation-meta if it exist or create one if it
* does not exist. */ * does not exist. */
rmeta = gst_buffer_add_analytics_relation_meta_full (buf, &rmeta_init_params); rmeta = gst_buffer_add_analytics_relation_meta_full (buf, &rmeta_init_params);
g_return_val_if_fail (rmeta != NULL, GST_FLOW_ERROR); g_assert (rmeta != NULL);
/* Decode masks_tensor and attach the information in a structured way /* Decode masks_tensor and attach the information in a structured way
* to rmeta. * to rmeta. */
* TODO: I think we need to send both tensors masks and logits
* to gst_fastsam_tensor_decoder_decode_masks_f32 since both are
* required simultanously to extract the segmentation. If this is the case
* we probably should rename gst_fastsam_tensor_decoder_decode_masks_f32 to
* gst_fastsam_tensor_decoder_decode_f32. */
gst_fastsam_tensor_decoder_decode_masks_f32 (self, masks_tensor, gst_fastsam_tensor_decoder_decode_masks_f32 (self, masks_tensor,
logits_tensor, rmeta); logits_tensor, rmeta);
@ -802,8 +757,8 @@ gst_fastsam_tensor_decoder_decode_masks_f32 (GstFastSAMTensorDecoder * self,
self->selected = selected; self->selected = selected;
} else { } else {
/* Reset lengths when we re-use arrays */ /* Reset lengths when we re-use arrays */
sel_candidates->len = 0; g_ptr_array_set_size (sel_candidates, 0);
selected->len = 0; g_ptr_array_set_size (selected, 0);
} }
/* masks_tensor->dims[2] contain the number of candidates. Let's call the /* masks_tensor->dims[2] contain the number of candidates. Let's call the
@ -941,8 +896,8 @@ gst_fastsam_tensor_decoder_decode_masks_f32 (GstFastSAMTensorDecoder * self,
for (gint mx = bb_mask.x; mx < MX_MAX; mx++, i++) { for (gint mx = bb_mask.x; mx < MX_MAX; mx++, i++) {
float sum = 0.0f; float sum = 0.0f;
j = my * self->mask_w + mx; j = my * self->mask_w + mx;
for (gint k = 0; k < 32; ++k) { for (gsize k = 0; k < logits_tensor->dims[1]; ++k) {
GST_TRACE_OBJECT (self, "protos data at (%d, %d) is %f", j, k, GST_TRACE_OBJECT (self, "protos data at (%d, %zu) is %f", j, k,
data_logits[k * self->mask_length + j]); data_logits[k * self->mask_length + j]);
sum += sum +=
MASK_X (candidate, k) * data_logits[k * self->mask_length + j]; MASK_X (candidate, k) * data_logits[k * self->mask_length + j];