patch_origins is now copied to the device; the kernel call now receives device addresses instead of host addresses

This commit is contained in:
Fabian Seuring 2022-07-29 12:12:25 +02:00
parent da55294019
commit 29d2915aa9
1 changed file with 8 additions and 3 deletions

11
main.cu
View File

@ -43,14 +43,17 @@ int main(int argc, char **argv) {
src[i] = i; src[i] = i;
} }
int* patch_origins_device;
uint8_t* src_device; uint8_t* src_device;
float* dst_device; float* dst_device;
cudaMalloc(&patch_origins_device, BATCH_SIZE * 2 * sizeof(int));
cudaMalloc(&src_device, src_size * sizeof(uint8_t)); cudaMalloc(&src_device, src_size * sizeof(uint8_t));
cudaMalloc(&dst_device, dst_size * sizeof(float)); cudaMalloc(&dst_device, dst_size * sizeof(float));
cudaMemcpy(patch_origins_device, &patch_origins, BATCH_SIZE * 2 * sizeof(int), cudaMemcpyHostToDevice);
cudaMemcpy(src_device, &src, src_size * sizeof(uint8_t), cudaMemcpyHostToDevice); cudaMemcpy(src_device, &src, src_size * sizeof(uint8_t), cudaMemcpyHostToDevice);
preprocess_kernel_img_to_batch(src, GPU_DST_COLS, GPU_DST_ROWS, dst, INPUT_W, INPUT_H, HORIZONTAL_PATCHES * VERTICAL_PATCHES, patch_origins, 0); preprocess_kernel_img_to_batch(src_device, GPU_DST_COLS, GPU_DST_ROWS, dst_device, INPUT_W, INPUT_H, HORIZONTAL_PATCHES * VERTICAL_PATCHES, patch_origins_device, 0);
cudaMemcpy(&dst, dst_device, dst_size * sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(&dst, dst_device, dst_size * sizeof(float), cudaMemcpyDeviceToHost);
@ -94,8 +97,10 @@ void preprocess_kernel_img_to_batch(
float* dst, int dst_width, int dst_height, int batch_size, int* patch_origins, float* dst, int dst_width, int dst_height, int batch_size, int* patch_origins,
cudaStream_t stream) { cudaStream_t stream) {
dim3 block(BLKX, BLKY, BLKZ); // dim3 block(BLKX, BLKY, BLKZ);
dim3 grid(3 * batch_size / BLKX, dst_width / BLKY, dst_height / BLKZ); // dim3 grid(3 * batch_size / BLKX, dst_width / BLKY, dst_height / BLKZ);
dim3 block(1, 1, 1);
dim3 grid(4, 4, 3);
batching_kernel<<<grid, block, 0>>>( batching_kernel<<<grid, block, 0>>>(
src, src_width*3, src_width, src, src_width*3, src_width,