ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
patch_match_cuda.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include <cuda_runtime.h>
11 
12 #include <iostream>
13 #include <memory>
14 #include <vector>
15 
16 #include "mvs/cuda_texture.h"
17 #include "mvs/depth_map.h"
18 #include "mvs/gpu_mat.h"
19 #include "mvs/gpu_mat_prng.h"
20 #include "mvs/gpu_mat_ref_image.h"
21 #include "mvs/image.h"
22 #include "mvs/normal_map.h"
23 #include "mvs/patch_match.h"
24 
25 namespace colmap {
26 namespace mvs {
27 
29 public:
31  const PatchMatch::Problem& problem);
32 
33  void Run();
34 
38  std::vector<int> GetConsistentImageIdxs() const;
39 
40 private:
41  template <int kWindowSize, int kWindowStep>
42  void RunWithWindowSizeAndStep();
43 
44  void ComputeCudaConfig();
45 
46  void BindRefImageTexture();
47 
48  void InitRefImage();
49  void InitSourceImages();
50  void InitTransforms();
51  void InitWorkspaceMemory();
52 
53  // Rotate reference image by 90 degrees in counter-clockwise direction.
54  void Rotate();
55 
56  const PatchMatchOptions options_;
57  const PatchMatch::Problem problem_;
58 
59  // Dimensions for sweeping from top to bottom, i.e. one thread per column.
60  dim3 sweep_block_size_;
61  dim3 sweep_grid_size_;
62  // Dimensions for element-wise operations, i.e. one thread per pixel.
63  dim3 elem_wise_block_size_;
64  dim3 elem_wise_grid_size_;
65 
66  // Original (not rotated) dimension of reference image.
67  size_t ref_width_;
68  size_t ref_height_;
69 
70  // Rotation of reference image in multiples of pi/2. This is equivalent to
71  // the number of calls to `Rotate` mod 4.
72  int rotation_in_half_pi_;
73 
74  // Reference and source image input data.
75  std::unique_ptr<CudaArrayLayeredTexture<uint8_t>> ref_image_texture_;
76  std::unique_ptr<CudaArrayLayeredTexture<uint8_t>> src_images_texture_;
77  std::unique_ptr<CudaArrayLayeredTexture<float>> src_depth_maps_texture_;
78 
79  // Relative poses from rotated versions of reference image to source images
80  // corresponding to `rotation_in_half_pi_`:
81  //
82  // [S(1), S(2), S(3), ..., S(n)]
83  //
84  // where n is the number of source images and:
85  //
86  // S(i) = [K_i(0, 0), K_i(0, 2), K_i(1, 1), K_i(1, 2), R_i(:), T_i(:),
87  // C_i(:), P_i(:), P_i^-1(:)]
88  //
89  // where i denotes the index of the source image and K is its calibration.
90  // R, T, C, P, P^-1 denote the relative rotation, translation, camera
91  // center, projection, and inverse projection from the reference to the
92  // i-th source image.
93  std::unique_ptr<CudaArrayLayeredTexture<float>> poses_texture_[4];
94 
95  // Calibration matrix for rotated versions of reference image
96  // as {K[0, 0], K[0, 2], K[1, 1], K[1, 2]} corresponding to
97  // `rotation_in_half_pi_`.
98  float ref_K_host_[4][4];
99  float ref_inv_K_host_[4][4];
100 
101  // Data for reference image.
102  std::unique_ptr<GpuMatRefImage> ref_image_;
103  std::unique_ptr<GpuMat<float>> depth_map_;
104  std::unique_ptr<GpuMat<float>> normal_map_;
105  std::unique_ptr<GpuMat<float>> sel_prob_map_;
106  std::unique_ptr<GpuMat<float>> prev_sel_prob_map_;
107  std::unique_ptr<GpuMat<float>> cost_map_;
108  std::unique_ptr<GpuMatPRNG> rand_state_map_;
109  std::unique_ptr<GpuMat<uint8_t>> consistency_mask_;
110 
111  // Shared memory is too small to hold local state for each thread,
112  // so this is workspace memory in global memory.
113  std::unique_ptr<GpuMat<float>> global_workspace_;
114 };
115 
116 } // namespace mvs
117 } // namespace colmap
NormalMap GetNormalMap() const
Mat< float > GetSelProbMap() const
PatchMatchCuda(const PatchMatchOptions &options, const PatchMatch::Problem &problem)
std::vector< int > GetConsistentImageIdxs() const
DepthMap GetDepthMap() const