fastsamtensordecoder: Set mask resolution based on model output
This commit is contained in:
parent
b7f964929c
commit
244dd01b22
@ -132,7 +132,7 @@ static void gst_fastsam_tensor_decoder_set_property (GObject * object,
|
|||||||
static void gst_fastsam_tensor_decoder_get_property (GObject * object,
|
static void gst_fastsam_tensor_decoder_get_property (GObject * object,
|
||||||
guint prop_id, GValue * value, GParamSpec * pspec);
|
guint prop_id, GValue * value, GParamSpec * pspec);
|
||||||
|
|
||||||
static void gst_fastsam_tensor_decoder_finalize (GObject * object);
|
static gboolean gst_fastsam_tensor_decoder_stop (GstBaseTransform * trans);
|
||||||
|
|
||||||
static GstFlowReturn gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform *
|
static GstFlowReturn gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform *
|
||||||
trans, GstBuffer * buf);
|
trans, GstBuffer * buf);
|
||||||
@ -160,9 +160,6 @@ gst_fastsam_tensor_decoder_class_init (GstFastSAMTensorDecoderClass * klass)
|
|||||||
gobject_class->set_property = gst_fastsam_tensor_decoder_set_property;
|
gobject_class->set_property = gst_fastsam_tensor_decoder_set_property;
|
||||||
gobject_class->get_property = gst_fastsam_tensor_decoder_get_property;
|
gobject_class->get_property = gst_fastsam_tensor_decoder_get_property;
|
||||||
|
|
||||||
/* Set GObject vmethod finalize */
|
|
||||||
gobject_class->finalize = gst_fastsam_tensor_decoder_finalize;
|
|
||||||
|
|
||||||
/* Define GstFastSAMTensorDecoder properties using GObject properties
|
/* Define GstFastSAMTensorDecoder properties using GObject properties
|
||||||
* interface.*/
|
* interface.*/
|
||||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||||
@ -254,6 +251,10 @@ gst_fastsam_tensor_decoder_class_init (GstFastSAMTensorDecoderClass * klass)
|
|||||||
basetransform_class->set_caps =
|
basetransform_class->set_caps =
|
||||||
GST_DEBUG_FUNCPTR (gst_fastsam_tensor_decoder_set_caps);
|
GST_DEBUG_FUNCPTR (gst_fastsam_tensor_decoder_set_caps);
|
||||||
|
|
||||||
|
/* Set GObject vmethod finalize */
|
||||||
|
basetransform_class->stop = gst_fastsam_tensor_decoder_stop;
|
||||||
|
|
||||||
|
|
||||||
/* Calculate the class id placeholder (also a quark) that will be set on all
|
/* Calculate the class id placeholder (also a quark) that will be set on all
|
||||||
* OD analytics-meta. */
|
* OD analytics-meta. */
|
||||||
OOI_CLASS_ID = g_quark_from_static_string ("FastSAM-None");
|
OOI_CLASS_ID = g_quark_from_static_string ("FastSAM-None");
|
||||||
@ -277,32 +278,31 @@ gst_fastsam_tensor_decoder_init (GstFastSAMTensorDecoder * self)
|
|||||||
self->max_detection = DEFAULT_MAX_DETECTION;
|
self->max_detection = DEFAULT_MAX_DETECTION;
|
||||||
self->sel_candidates = NULL;
|
self->sel_candidates = NULL;
|
||||||
self->selected = NULL;
|
self->selected = NULL;
|
||||||
self->mask_w = 256;
|
self->mask_w = 0;
|
||||||
self->mask_h = 256;
|
self->mask_h = 0;
|
||||||
self->mask_length = self->mask_w * self->mask_h;
|
self->mask_length = 0;
|
||||||
memset (&self->mask_roi, 0, sizeof (BBox));
|
memset (&self->mask_roi, 0, sizeof (BBox));
|
||||||
self->mask_pool = NULL;
|
self->mask_pool = NULL;
|
||||||
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), FALSE);
|
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), FALSE);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static gboolean
|
||||||
gst_fastsam_tensor_decoder_finalize (GObject * object)
|
gst_fastsam_tensor_decoder_stop (GstBaseTransform * trans)
|
||||||
{
|
{
|
||||||
GstFastSAMTensorDecoder *self = GST_FASTSAM_TENSOR_DECODER (object);
|
GstFastSAMTensorDecoder *self = GST_FASTSAM_TENSOR_DECODER (trans);
|
||||||
|
|
||||||
if (self->sel_candidates) {
|
self->mask_w = 0;
|
||||||
g_ptr_array_unref (g_steal_pointer (&self->sel_candidates));
|
self->mask_h = 0;
|
||||||
}
|
self->mask_length = 0;
|
||||||
|
|
||||||
if (self->selected) {
|
g_clear_pointer (&self->sel_candidates, g_ptr_array_unref);
|
||||||
g_ptr_array_unref (g_steal_pointer (&self->selected));
|
g_clear_pointer (&self->selected, g_ptr_array_unref);
|
||||||
}
|
if (self->mask_pool)
|
||||||
|
gst_buffer_pool_set_active (self->mask_pool, FALSE);
|
||||||
|
|
||||||
if (self->mask_pool) {
|
g_clear_object (&self->mask_pool);
|
||||||
gst_object_unref (self->mask_pool);
|
|
||||||
}
|
|
||||||
|
|
||||||
G_OBJECT_CLASS (gst_fastsam_tensor_decoder_parent_class)->finalize (object);
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
@ -458,46 +458,6 @@ gst_fastsam_tensor_decoder_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
|||||||
return FALSE;
|
return FALSE;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* The masks need to be cropped to fit the SAR of the image. */
|
|
||||||
/* TODO: We're reconstructing the transformation that was done on the
|
|
||||||
* original image based on the assumption that the complete image without
|
|
||||||
* deformation would be analyzed. This assumption is not alway true and
|
|
||||||
* we should try to find a way to convey this transformation information
|
|
||||||
* and retrieve from here to know the transformation that need to be done
|
|
||||||
* on the mask.*/
|
|
||||||
|
|
||||||
if (self->video_info.width > self->video_info.height) {
|
|
||||||
self->bb2mask_gain = ((gfloat) self->mask_w) / self->video_info.width;
|
|
||||||
self->mask_roi.x = 0;
|
|
||||||
self->mask_roi.w = self->mask_w;
|
|
||||||
self->mask_roi.h = ((gfloat) self->bb2mask_gain) * self->video_info.height;
|
|
||||||
self->mask_roi.y = (self->mask_h - self->mask_roi.h) / 2;
|
|
||||||
} else {
|
|
||||||
self->bb2mask_gain = ((gfloat) self->mask_h) / self->video_info.height;
|
|
||||||
self->mask_roi.y = 0;
|
|
||||||
self->mask_roi.h = self->mask_h;
|
|
||||||
self->mask_roi.w = self->bb2mask_gain * self->video_info.width;
|
|
||||||
self->mask_roi.x = (self->mask_w - self->mask_roi.w) / 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (self->mask_pool == NULL) {
|
|
||||||
GstVideoInfo minfo;
|
|
||||||
GstCaps *caps;
|
|
||||||
gst_video_info_init (&minfo);
|
|
||||||
gst_video_info_set_format (&minfo, GST_VIDEO_FORMAT_GRAY8, 256, 256);
|
|
||||||
caps = gst_video_info_to_caps (&minfo);
|
|
||||||
self->mask_pool = gst_video_buffer_pool_new ();
|
|
||||||
GstStructure *config = gst_buffer_pool_get_config (self->mask_pool);
|
|
||||||
gst_buffer_pool_config_set_params (config, caps, self->mask_length, 0, 0);
|
|
||||||
gst_buffer_pool_config_add_option (config,
|
|
||||||
GST_BUFFER_POOL_OPTION_VIDEO_META);
|
|
||||||
g_return_val_if_fail (gst_buffer_pool_set_config (self->mask_pool, config),
|
|
||||||
FALSE);
|
|
||||||
g_return_val_if_fail (gst_buffer_pool_set_active (self->mask_pool, TRUE),
|
|
||||||
FALSE);
|
|
||||||
gst_caps_unref (caps);
|
|
||||||
}
|
|
||||||
|
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -522,10 +482,10 @@ gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform * trans,
|
|||||||
&logits_tensor))
|
&logits_tensor))
|
||||||
return GST_FLOW_OK;
|
return GST_FLOW_OK;
|
||||||
|
|
||||||
if (masks_tensor->num_dims < 3) {
|
if (masks_tensor->num_dims != 3) {
|
||||||
GST_ELEMENT_ERROR (self, STREAM, DECODE, (NULL),
|
GST_ELEMENT_ERROR (self, STREAM, DECODE, (NULL),
|
||||||
("Masks tensor must have at least 3 dimensions,"
|
("Masks tensor must have 3 dimensions but has %zu",
|
||||||
"but only has %zu", masks_tensor->num_dims));
|
masks_tensor->num_dims));
|
||||||
return GST_FLOW_ERROR;
|
return GST_FLOW_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -603,15 +563,10 @@ gst_fastsam_tensor_decoder_transform_ip (GstBaseTransform * trans,
|
|||||||
* retrieve an analytics-relation-meta if it exist or create one if it
|
* retrieve an analytics-relation-meta if it exist or create one if it
|
||||||
* does not exist. */
|
* does not exist. */
|
||||||
rmeta = gst_buffer_add_analytics_relation_meta_full (buf, &rmeta_init_params);
|
rmeta = gst_buffer_add_analytics_relation_meta_full (buf, &rmeta_init_params);
|
||||||
g_return_val_if_fail (rmeta != NULL, GST_FLOW_ERROR);
|
g_assert (rmeta != NULL);
|
||||||
|
|
||||||
/* Decode masks_tensor and attach the information in a structured way
|
/* Decode masks_tensor and attach the information in a structured way
|
||||||
* to rmeta.
|
* to rmeta. */
|
||||||
* TODO: I think we need to send both tensors masks and logits
|
|
||||||
* to gst_fastsam_tensor_decoder_decode_masks_f32 since both are
|
|
||||||
* required simultanously to extract the segmentation. If this is the case
|
|
||||||
* we probably should rename gst_fastsam_tensor_decoder_decode_masks_f32 to
|
|
||||||
* gst_fastsam_tensor_decoder_decode_f32. */
|
|
||||||
gst_fastsam_tensor_decoder_decode_masks_f32 (self, masks_tensor,
|
gst_fastsam_tensor_decoder_decode_masks_f32 (self, masks_tensor,
|
||||||
logits_tensor, rmeta);
|
logits_tensor, rmeta);
|
||||||
|
|
||||||
@ -802,8 +757,8 @@ gst_fastsam_tensor_decoder_decode_masks_f32 (GstFastSAMTensorDecoder * self,
|
|||||||
self->selected = selected;
|
self->selected = selected;
|
||||||
} else {
|
} else {
|
||||||
/* Reset lengths when we re-use arrays */
|
/* Reset lengths when we re-use arrays */
|
||||||
sel_candidates->len = 0;
|
g_ptr_array_set_size (sel_candidates, 0);
|
||||||
selected->len = 0;
|
g_ptr_array_set_size (selected, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* masks_tensor->dims[2] contain the number of candidates. Let's call the
|
/* masks_tensor->dims[2] contain the number of candidates. Let's call the
|
||||||
@ -941,8 +896,8 @@ gst_fastsam_tensor_decoder_decode_masks_f32 (GstFastSAMTensorDecoder * self,
|
|||||||
for (gint mx = bb_mask.x; mx < MX_MAX; mx++, i++) {
|
for (gint mx = bb_mask.x; mx < MX_MAX; mx++, i++) {
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
j = my * self->mask_w + mx;
|
j = my * self->mask_w + mx;
|
||||||
for (gint k = 0; k < 32; ++k) {
|
for (gsize k = 0; k < logits_tensor->dims[1]; ++k) {
|
||||||
GST_TRACE_OBJECT (self, "protos data at (%d, %d) is %f", j, k,
|
GST_TRACE_OBJECT (self, "protos data at (%d, %zu) is %f", j, k,
|
||||||
data_logits[k * self->mask_length + j]);
|
data_logits[k * self->mask_length + j]);
|
||||||
sum +=
|
sum +=
|
||||||
MASK_X (candidate, k) * data_logits[k * self->mask_length + j];
|
MASK_X (candidate, k) * data_logits[k * self->mask_length + j];
|
||||||
|
Loading…
x
Reference in New Issue
Block a user