From 16e94b7fc35b0255aa49b23b5c3cf4234ccbaea7 Mon Sep 17 00:00:00 2001
From: Seungha Yang <seungha@centricular.com>
Date: Mon, 16 Sep 2024 23:24:30 +0900
Subject: [PATCH] nvcodec: Add support for CUDA to D3D12 memory copy

Add a CUDA -> D3D12 memory copy method to GstCudaD3D12Interop.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/7529>
---
 .../sys/nvcodec/gstcudainterop_d3d12.cpp      | 190 ++++++++++++++++--
 .../sys/nvcodec/gstcudainterop_d3d12.h        |   8 +-
 .../sys/nvcodec/gstnvencoder.cpp              |   2 +-
 3 files changed, 179 insertions(+), 21 deletions(-)

diff --git a/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.cpp b/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.cpp
index 262b213f20..0275459e45 100644
--- a/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.cpp
+++ b/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.cpp
@@ -265,7 +265,7 @@ gst_cuda_d3d12_interop_init (GstCudaD3D12Interop * self)
 
 GstCudaD3D12Interop *
 gst_cuda_d3d12_interop_new (GstCudaContext * context, GstD3D12Device * device,
-    const GstVideoInfo * info)
+    const GstVideoInfo * info, gboolean is_uploader)
 {
   gint64 cuda_luid = 0;
   gint64 d3d_luid = 0;
@@ -320,11 +320,19 @@ gst_cuda_d3d12_interop_new (GstCudaContext * context, GstD3D12Device * device,
   auto device_handle = gst_d3d12_device_get_device_handle (device);
   priv->alloc_info = device_handle->GetResourceAllocationInfo (0, 1, &desc);
 
-  priv->in_fence = gst_d3d12_device_get_fence_handle (device,
-      D3D12_COMMAND_LIST_TYPE_COMPUTE);
+  HRESULT hr;
+  if (is_uploader) {
+    priv->in_fence = gst_d3d12_device_get_fence_handle (device,
+        D3D12_COMMAND_LIST_TYPE_COMPUTE);
+    hr = device_handle->CreateFence (0, D3D12_FENCE_FLAG_SHARED,
+        IID_PPV_ARGS (&priv->out_fence));
+  } else {
+    priv->out_fence = gst_d3d12_device_get_fence_handle (device,
+        D3D12_COMMAND_LIST_TYPE_COMPUTE);
+    hr = device_handle->CreateFence (0, D3D12_FENCE_FLAG_SHARED,
+        IID_PPV_ARGS (&priv->in_fence));
+  }
 
-  auto hr = device_handle->CreateFence (0, D3D12_FENCE_FLAG_SHARED,
-      IID_PPV_ARGS (&priv->out_fence));
   if (!gst_d3d12_result (hr, device)) {
     gst_object_unref (self);
     return nullptr;
@@ -353,22 +361,24 @@ gst_cuda_d3d12_interop_new (GstCudaContext * context, GstD3D12Device * device,
     return nullptr;
   }
 
-  hr = device_handle->CreateSharedHandle (priv->out_fence.Get (), nullptr,
-      GENERIC_ALL, nullptr, &nt_handle);
-  if (!gst_d3d12_result (hr, device)) {
+  if (is_uploader) {
+    hr = device_handle->CreateSharedHandle (priv->out_fence.Get (), nullptr,
+        GENERIC_ALL, nullptr, &nt_handle);
+    if (!gst_d3d12_result (hr, device)) {
+      gst_cuda_context_pop (nullptr);
+      gst_object_unref (self);
+      return nullptr;
+    }
+
+    sem_desc.handle.win32.handle = nt_handle;
+    cuda_ret = CuImportExternalSemaphore (&priv->out_sem, &sem_desc);
+    CloseHandle (nt_handle);
     gst_cuda_context_pop (nullptr);
-    gst_object_unref (self);
-    return nullptr;
-  }
 
-  sem_desc.handle.win32.handle = nt_handle;
-  cuda_ret = CuImportExternalSemaphore (&priv->out_sem, &sem_desc);
-  CloseHandle (nt_handle);
-  gst_cuda_context_pop (nullptr);
-
-  if (!gst_cuda_result (cuda_ret)) {
-    gst_object_unref (self);
-    return nullptr;
+    if (!gst_cuda_result (cuda_ret)) {
+      gst_object_unref (self);
+      return nullptr;
+    }
   }
 
   priv->fence_waiter =
@@ -675,3 +685,173 @@ gst_cuda_d3d12_interop_upload_async (GstCudaD3D12Interop * interop,
 
   return TRUE;
 }
+
+/*
+ * Schedules an asynchronous CUDA -> D3D12 copy of src_cuda into dst_d3d12.
+ *
+ * The CUDA frame is first copied plane by plane into an interop resource
+ * with CuMemcpy2DAsync() on the given stream, then the interop semaphore is
+ * signalled with the next fence value on the same stream. A D3D12 texture
+ * copy from the interop resource into dst_d3d12 is submitted afterwards with
+ * the interop fence and that fence value, and the fence value reported by
+ * the D3D12 copy is attached to dst_d3d12.
+ *
+ * interop:   interop object holding device, context and layout state
+ * dst_d3d12: destination buffer, mapped with GST_MAP_WRITE_D3D12
+ * src_cuda:  source buffer, mapped with GST_MAP_READ | GST_MAP_CUDA
+ * stream:    CUDA stream whose handle is used for the copy and the signal
+ *
+ * Returns TRUE when both copies were scheduled, FALSE on any failure. This
+ * function's mappings and its resource reference are released on failure.
+ */
+gboolean
+gst_cuda_d3d12_interop_download_async (GstCudaD3D12Interop * interop,
+    GstBuffer * dst_d3d12, GstBuffer * src_cuda, GstCudaStream * stream)
+{
+  GstD3D12Frame frame_12;
+  GstVideoFrame frame_cuda;
+
+  auto priv = interop->priv;
+
+  if (!gst_d3d12_frame_map (&frame_12, &priv->info,
+          dst_d3d12, GST_MAP_WRITE_D3D12, GST_D3D12_FRAME_MAP_FLAG_NONE)) {
+    GST_ERROR_OBJECT (interop, "Couldn't map d3d12 buffer");
+    return FALSE;
+  }
+
+  if (!gst_d3d12_device_is_equal (priv->device, frame_12.device)) {
+    GST_WARNING_OBJECT (interop, "Different d3d12 device");
+    gst_d3d12_frame_unmap (&frame_12);
+    return FALSE;
+  }
+
+  if (!gst_video_frame_map (&frame_cuda, &priv->info, src_cuda,
+          (GstMapFlags) (GST_MAP_READ | GST_MAP_CUDA))) {
+    GST_ERROR_OBJECT (interop, "Couldn't map cuda buffer");
+    gst_d3d12_frame_unmap (&frame_12);
+    return FALSE;
+  }
+
+  GstCudaD3D12InteropResource *resource;
+  if (!gst_cuda_d3d12_interop_acquire_resource (interop, &resource)) {
+    GST_ERROR_OBJECT (interop, "Couldn't acquire resource");
+    gst_d3d12_frame_unmap (&frame_12);
+    gst_video_frame_unmap (&frame_cuda);
+    return FALSE;
+  }
+
+  if (!gst_cuda_context_push (priv->context)) {
+    GST_ERROR_OBJECT (interop, "Couldn't push context");
+    gst_d3d12_frame_unmap (&frame_12);
+    gst_video_frame_unmap (&frame_cuda);
+    /* The resource is already acquired at this point, drop our reference */
+    gst_mini_object_unref (resource);
+    return FALSE;
+  }
+
+  auto stream_handle = gst_cuda_stream_get_handle (stream);
+  for (guint i = 0; i < GST_VIDEO_FRAME_N_PLANES (&frame_cuda); i++) {
+    CUDA_MEMCPY2D copy_params = { };
+    guint8 *dst_data = (guint8 *) resource->devptr;
+
+    /* Each plane lands at its placed-footprint offset in the resource */
+    dst_data += priv->layout[i].Offset;
+
+    copy_params.srcMemoryType = CU_MEMORYTYPE_DEVICE;
+    copy_params.srcDevice = (CUdeviceptr)
+        GST_VIDEO_FRAME_PLANE_DATA (&frame_cuda, i);
+    copy_params.srcPitch = GST_VIDEO_FRAME_PLANE_STRIDE (&frame_cuda, i);
+
+    copy_params.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+    copy_params.dstDevice = (CUdeviceptr) dst_data;
+    copy_params.dstPitch = priv->layout[i].Footprint.RowPitch;
+
+    copy_params.WidthInBytes = GST_VIDEO_FRAME_COMP_WIDTH (&frame_cuda, i) *
+        GST_VIDEO_FRAME_COMP_PSTRIDE (&frame_cuda, i);
+    copy_params.Height = GST_VIDEO_FRAME_COMP_HEIGHT (&frame_cuda, i);
+
+    auto cuda_ret = CuMemcpy2DAsync (&copy_params, stream_handle);
+    if (!gst_cuda_result (cuda_ret)) {
+      GST_ERROR_OBJECT (interop, "CuMemcpy2DAsync failed");
+      gst_video_frame_unmap (&frame_cuda);
+      gst_d3d12_frame_unmap (&frame_12);
+      gst_mini_object_unref (resource);
+      gst_cuda_context_pop (nullptr);
+
+      return FALSE;
+    }
+  }
+
+  /* Queue a signal of the next fence value on the same stream as the
+   * copies; the value is rolled back below if the signal can't be queued */
+  priv->fence_val++;
+  CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS signal_params = { };
+  signal_params.params.fence.value = priv->fence_val;
+
+  auto cuda_ret = CuSignalExternalSemaphoresAsync (&priv->in_sem,
+      &signal_params, 1, stream_handle);
+  gst_cuda_context_pop (nullptr);
+  gst_video_frame_unmap (&frame_cuda);
+
+  if (!gst_cuda_result (cuda_ret)) {
+    GST_ERROR_OBJECT (interop, "CuSignalExternalSemaphoresAsync failed");
+    /* The d3d12 frame is still mapped here, unmap it before bailing out */
+    gst_d3d12_frame_unmap (&frame_12);
+    gst_mini_object_unref (resource);
+    priv->fence_val--;
+
+    return FALSE;
+  }
+
+  GstD3D12FenceData *fence_data;
+  gst_d3d12_fence_data_pool_acquire (priv->fence_data_pool, &fence_data);
+
+  /* Attach an extra resource reference to the fence data */
+  gst_d3d12_fence_data_push (fence_data,
+      FENCE_NOTIFY_MINI_OBJECT (gst_mini_object_ref (resource)));
+
+  GstD3D12CopyTextureRegionArgs args[GST_VIDEO_MAX_PLANES] = { };
+  D3D12_BOX src_box[GST_VIDEO_MAX_PLANES] = { };
+
+  /* NOTE(review): src_box is filled in below but nothing in this function
+   * hands it to args[] - confirm whether the copy args should reference it */
+  for (guint i = 0; i < GST_VIDEO_INFO_N_PLANES (&priv->info); i++) {
+    src_box[i].left = 0;
+    src_box[i].top = 0;
+    src_box[i].right = MIN (frame_12.plane_rect[i].right,
+        priv->layout[i].Footprint.Width);
+    src_box[i].bottom = MIN (frame_12.plane_rect[i].bottom,
+        priv->layout[i].Footprint.Height);
+    src_box[i].front = 0;
+    src_box[i].back = 1;
+
+    args[i].src.Type = D3D12_TEXTURE_COPY_TYPE_PLACED_FOOTPRINT;
+    args[i].src.pResource = resource->resource.Get ();
+    args[i].src.PlacedFootprint = priv->layout[i];
+
+    args[i].dst.Type = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX;
+    args[i].dst.pResource = frame_12.data[i];
+    args[i].dst.SubresourceIndex = frame_12.subresource_index[i];
+  }
+
+  auto in_fence = priv->in_fence.Get ();
+  guint64 fence_val;
+  auto ret = gst_d3d12_device_copy_texture_region (priv->device,
+      GST_VIDEO_INFO_N_PLANES (&priv->info), args, fence_data,
+      1, &in_fence, &priv->fence_val, D3D12_COMMAND_LIST_TYPE_COMPUTE,
+      &fence_val);
+  gst_d3d12_frame_unmap (&frame_12);
+
+  if (!ret) {
+    GST_ERROR_OBJECT (interop, "Couldn't execute d3d12 copy");
+    gst_mini_object_unref (resource);
+    return FALSE;
+  }
+
+  priv->fence_waiter->wait_async (priv->fence_val, resource);
+
+  gst_d3d12_buffer_set_fence (dst_d3d12, priv->out_fence.Get (), fence_val,
+      FALSE);
+
+  return TRUE;
+}
diff --git a/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.h b/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.h
index 21d5324985..41d7dadcd7 100644
--- a/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.h
+++ b/subprojects/gst-plugins-bad/sys/nvcodec/gstcudainterop_d3d12.h
@@ -34,12 +34,18 @@ GType gst_cuda_d3d12_interop_resource_get_type (void);
 
 GstCudaD3D12Interop * gst_cuda_d3d12_interop_new (GstCudaContext * context,
                                                   GstD3D12Device * device,
-                                                  const GstVideoInfo * info);
+                                                  const GstVideoInfo * info,
+                                                  gboolean is_uploader);
 
 gboolean gst_cuda_d3d12_interop_upload_async (GstCudaD3D12Interop * interop,
                                               GstBuffer * dst_cuda,
                                               GstBuffer * src_d3d12,
                                               GstCudaStream * stream);
 
+gboolean gst_cuda_d3d12_interop_download_async (GstCudaD3D12Interop * interop,
+                                                GstBuffer * dst_d3d12,
+                                                GstBuffer * src_cuda,
+                                                GstCudaStream * stream);
+
 G_END_DECLS
 
diff --git a/subprojects/gst-plugins-bad/sys/nvcodec/gstnvencoder.cpp b/subprojects/gst-plugins-bad/sys/nvcodec/gstnvencoder.cpp
index 7a8d565e30..75a895bb16 100644
--- a/subprojects/gst-plugins-bad/sys/nvcodec/gstnvencoder.cpp
+++ b/subprojects/gst-plugins-bad/sys/nvcodec/gstnvencoder.cpp
@@ -1315,7 +1315,7 @@ gst_nv_encoder_set_format (GstVideoEncoder * encoder,
         gst_d3d12_ensure_element_data_for_adapter_luid (GST_ELEMENT (self),
             priv->dxgi_adapter_luid, &priv->device_12)) {
       priv->interop_12 = gst_cuda_d3d12_interop_new (priv->context,
-          priv->device_12, &state->info);
+          priv->device_12, &state->info, TRUE);
     }
   }
 #endif