MemoryManagerCached.cpp
// ----------------------------------------------------------------------------
// -                     CloudViewer: www.cloudViewer.org                     -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#include <Logging.h>

#include <algorithm>
#include <cstdlib>
#include <limits>
#include <memory>
#include <mutex>
#include <set>
#include <unordered_map>
#include <vector>


#ifdef BUILD_CUDA_MODULE
#endif

namespace cloudViewer {
namespace core {

// This implementation is inspired by PyTorch's CUDA memory manager.
// Reference: https://git.io/JUqUA

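// Comparators used as strict weak orderings for the std::set containers below:
// free blocks are ordered by size to enable best-fit lookups, while the
// virtual blocks of a real block are ordered by address to enable merging of
// neighboring blocks.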
template <typename Block>
struct SizeOrder {
    bool operator()(const std::shared_ptr<Block>& lhs,
                    const std::shared_ptr<Block>& rhs) const {
        if (lhs->byte_size_ != rhs->byte_size_) {
            return lhs->byte_size_ < rhs->byte_size_;
        }
        return lhs->ptr_ < rhs->ptr_;
    }
};

template <typename Block>
struct PointerOrder {
    bool operator()(const std::shared_ptr<Block>& lhs,
                    const std::shared_ptr<Block>& rhs) const {
        if (lhs->ptr_ != rhs->ptr_) {
            return lhs->ptr_ < rhs->ptr_;
        }
        return lhs->byte_size_ < rhs->byte_size_;
    }
};

struct RealBlock;

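// A virtual block is a sub-range of a real block that is handed out to the
// user. It refers back to its owning real block through a weak pointer.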
struct VirtualBlock {
    VirtualBlock(void* ptr,
                 size_t byte_size,
                 const std::weak_ptr<RealBlock>& r_block)
        : ptr_(ptr), byte_size_(byte_size), r_block_(r_block) {}

    void* ptr_ = nullptr;
    size_t byte_size_ = 0;

    std::weak_ptr<RealBlock> r_block_;
};

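// A real block is a single allocation obtained from the underlying device
// memory manager. It owns the virtual blocks that partition it.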
struct RealBlock {
    RealBlock(void* ptr, size_t byte_size) : ptr_(ptr), byte_size_(byte_size) {}

    void* ptr_ = nullptr;
    size_t byte_size_ = 0;

    std::shared_ptr<MemoryManagerDevice> device_mm_;
    std::set<std::shared_ptr<VirtualBlock>, PointerOrder<VirtualBlock>>
            v_blocks_;
};
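
// Per-device cache that tracks real blocks and their virtual sub-blocks.
// All public member functions are synchronized via a recursive mutex.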
77 
78 class MemoryCache {
79 public:
80  MemoryCache() = default;
81  MemoryCache(const MemoryCache&) = delete;
82  MemoryCache& operator=(const MemoryCache&) = delete;
83 
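    /// Rounds byte_size up to the next multiple of alignment.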
    static size_t AlignByteSize(size_t byte_size, size_t alignment = 8) {
        return ((byte_size + alignment - 1) / alignment) * alignment;
    }

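    /// Returns a pointer to a cached block of at least byte_size bytes, or
    /// nullptr if no suitable free block is available in the cache.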
    void* Malloc(size_t byte_size) {
        std::lock_guard<std::recursive_mutex> lock(mutex_);

        auto free_block = ExtractFreeBlock(byte_size);

        if (free_block != nullptr) {
            size_t remaining_size = free_block->byte_size_ - byte_size;

            if (remaining_size == 0) {
                // No update of real block required for perfect fit.
                allocated_virtual_blocks_.emplace(free_block->ptr_, free_block);

                return free_block->ptr_;
            } else {
                // Split virtual block.
                auto new_block = std::make_shared<VirtualBlock>(
                        free_block->ptr_, byte_size, free_block->r_block_);
                auto remaining_block = std::make_shared<VirtualBlock>(
                        static_cast<char*>(free_block->ptr_) + byte_size,
                        remaining_size, free_block->r_block_);

                // Update real block.
                auto real_block = free_block->r_block_.lock();
                real_block->v_blocks_.erase(free_block);
                real_block->v_blocks_.insert(new_block);
                real_block->v_blocks_.insert(remaining_block);

                allocated_virtual_blocks_.emplace(new_block->ptr_, new_block);
                free_virtual_blocks_.insert(remaining_block);

                return new_block->ptr_;
            }
        }

        return nullptr;
    }

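    /// Returns the block at ptr to the cache and merges it with adjacent free
    /// virtual blocks of the same real block.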
    void Free(void* ptr) {
        std::lock_guard<std::recursive_mutex> lock(mutex_);

        auto ptr_it = allocated_virtual_blocks_.find(ptr);

        if (ptr_it == allocated_virtual_blocks_.end()) {
            // Should never reach here.
            utility::LogError("Block of {} should have been recorded.",
                              fmt::ptr(ptr));
        }

        auto v_block = ptr_it->second;
        allocated_virtual_blocks_.erase(ptr_it);

        auto r_block = v_block->r_block_.lock();
        auto& v_block_set = r_block->v_blocks_;

        const auto v_block_it = v_block_set.find(v_block);
        if (v_block_it == v_block_set.end()) {
            utility::LogError(
                    "Virtual block ({} @ {} bytes) not recorded in real block "
                    "{} @ {} bytes.",
                    fmt::ptr(v_block->ptr_), v_block->byte_size_,
                    fmt::ptr(r_block->ptr_), r_block->byte_size_);
        }

        auto merged_v_block = v_block;

        // Merge with previous block.
        if (v_block_it != v_block_set.begin()) {
            // Use copy to keep original iterator unchanged.
            auto v_block_it_copy = v_block_it;
            auto v_block_it_prev = --v_block_it_copy;

            auto v_block_prev = *v_block_it_prev;

            if (free_virtual_blocks_.find(v_block_prev) !=
                free_virtual_blocks_.end()) {
                // Update merged block.
                merged_v_block = std::make_shared<VirtualBlock>(
                        v_block_prev->ptr_,
                        v_block_prev->byte_size_ + merged_v_block->byte_size_,
                        r_block);

                // Remove from sets.
                v_block_set.erase(v_block_prev);
                free_virtual_blocks_.erase(v_block_prev);
            }
        }

        // Merge with next block.

        // Use copy to keep original iterator unchanged.
        auto v_block_it_copy = v_block_it;
        auto v_block_it_next = ++v_block_it_copy;

        if (v_block_it_next != v_block_set.end()) {
            auto v_block_next = *v_block_it_next;

            if (free_virtual_blocks_.find(v_block_next) !=
                free_virtual_blocks_.end()) {
                // Update merged block.
                merged_v_block = std::make_shared<VirtualBlock>(
                        merged_v_block->ptr_,
                        merged_v_block->byte_size_ + v_block_next->byte_size_,
                        r_block);

                // Remove from sets.
                v_block_set.erase(v_block_next);
                free_virtual_blocks_.erase(v_block_next);
            }
        }

        v_block_set.erase(v_block);
        v_block_set.insert(merged_v_block);
        free_virtual_blocks_.insert(merged_v_block);
    }

    /// Acquires ownership of the new real allocated blocks.
    void Acquire(void* ptr,
                 size_t byte_size,
                 const std::shared_ptr<MemoryManagerDevice>& device_mm) {
        std::lock_guard<std::recursive_mutex> lock(mutex_);

        auto r_block = std::make_shared<RealBlock>(ptr, byte_size);
        auto v_block = std::make_shared<VirtualBlock>(ptr, byte_size, r_block);
        r_block->device_mm_ = device_mm;
        r_block->v_blocks_.insert(v_block);

        real_blocks_.insert(r_block);
        allocated_virtual_blocks_.emplace(v_block->ptr_, v_block);
    }

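    /// Releases ownership of unused real blocks, greedily selected until at
    /// least byte_size bytes are covered or no releasable blocks remain.
    /// Returns the released pointers together with their owning device memory
    /// managers so the caller can free the underlying memory.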
    std::vector<std::pair<void*, std::shared_ptr<MemoryManagerDevice>>> Release(
            size_t byte_size) {
        std::lock_guard<std::recursive_mutex> lock(mutex_);

        // Filter releasable blocks.
        std::set<std::shared_ptr<RealBlock>, SizeOrder<RealBlock>>
                releasable_real_blocks;
        std::copy_if(
                real_blocks_.begin(), real_blocks_.end(),
                std::inserter(releasable_real_blocks,
                              releasable_real_blocks.begin()),
                [this](const auto& r_block) { return IsReleasable(r_block); });

        // Determine greedy "minimal" subset.
        std::vector<std::pair<void*, std::shared_ptr<MemoryManagerDevice>>>
                released_pointers;
        size_t released_size = 0;
        while (!releasable_real_blocks.empty() && released_size < byte_size) {
            size_t remaining_size = byte_size - released_size;
            auto query_size =
                    std::make_shared<RealBlock>(nullptr, remaining_size);
            auto it = releasable_real_blocks.lower_bound(query_size);
            if (it == releasable_real_blocks.end()) {
                --it;
            }
            auto r_block = *it;

            real_blocks_.erase(r_block);
            for (const auto& v_block : r_block->v_blocks_) {
                free_virtual_blocks_.erase(v_block);
            }

            releasable_real_blocks.erase(r_block);
            released_pointers.emplace_back(r_block->ptr_, r_block->device_mm_);
            released_size += r_block->byte_size_;
        }

        return released_pointers;
    }

    /// Releases ownership of all unused real allocated blocks.
    std::vector<std::pair<void*, std::shared_ptr<MemoryManagerDevice>>>
    ReleaseAll() {
        return Release(std::numeric_limits<size_t>::max());
    }

    /// Returns the number of allocated real blocks.
    size_t Size() const { return real_blocks_.size(); }

    /// True if the set of allocated real blocks is empty, false otherwise.
    bool Empty() const { return Size() == 0; }

private:
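    /// Extracts a free virtual block whose size lies within
    /// [byte_size, kMaxFragmentation * byte_size] and whose real block also
    /// respects this upper bound. Returns nullptr if no such block exists.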
    std::shared_ptr<VirtualBlock> ExtractFreeBlock(size_t byte_size) {
        std::lock_guard<std::recursive_mutex> lock(mutex_);

        size_t max_byte_size = static_cast<size_t>(
                kMaxFragmentation * static_cast<double>(byte_size));

        // Consider blocks with size in range [byte_size, max_byte_size].
        auto query_size = std::make_shared<VirtualBlock>(
                nullptr, byte_size, std::weak_ptr<RealBlock>());
        auto it = free_virtual_blocks_.lower_bound(query_size);
        while (it != free_virtual_blocks_.end() &&
               (*it)->byte_size_ <= max_byte_size) {
            auto r_block = (*it)->r_block_.lock();
            if (r_block->byte_size_ <= max_byte_size) {
                auto block = *it;
                free_virtual_blocks_.erase(it);
                return block;
            }
            ++it;
        }

        return nullptr;
    }

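    /// A real block is releasable if it is covered by exactly one virtual
    /// block and that virtual block is currently free.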
    bool IsReleasable(const std::shared_ptr<RealBlock>& r_block) {
        if (r_block->v_blocks_.size() != 1) {
            return false;
        }

        auto v_block = *(r_block->v_blocks_.begin());
        if (r_block->ptr_ != v_block->ptr_ ||
            r_block->byte_size_ != v_block->byte_size_) {
            utility::LogError(
                    "Real block {} @ {} bytes has single "
                    "virtual block {} @ {} bytes",
                    fmt::ptr(r_block->ptr_), r_block->byte_size_,
                    fmt::ptr(v_block->ptr_), v_block->byte_size_);
        }

        return free_virtual_blocks_.find(v_block) != free_virtual_blocks_.end();
    }

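    /// Upper bound on the ratio between the size of a reused cached block and
    /// the requested size; limits internal fragmentation on cache hits.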
    const double kMaxFragmentation = 4.0;

    std::set<std::shared_ptr<RealBlock>, SizeOrder<RealBlock>> real_blocks_;

    std::unordered_map<void*, std::shared_ptr<VirtualBlock>>
            allocated_virtual_blocks_;
    std::set<std::shared_ptr<VirtualBlock>, SizeOrder<VirtualBlock>>
            free_virtual_blocks_;

    std::recursive_mutex mutex_;
};

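// Process-wide singleton that owns one MemoryCache per device and falls back
// to the underlying device memory manager whenever the cache cannot serve a
// request.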
class Cacher {
public:
    static Cacher& GetInstance() {
        // Ensure the static Logger instance is instantiated before the
        // Cacher instance.
        // Since destruction of static instances happens in reverse order,
        // this guarantees that the Logger can be used at any point in time.
        utility::Logger::GetInstance();

#ifdef BUILD_CUDA_MODULE
        // Ensure CUDAState is initialized before Cacher.
        CUDAState::GetInstance();
#endif

        static Cacher instance;
        return instance;
    }

    ~Cacher() {
        for (const auto& cache_pair : device_caches_) {
            // Simulate C++17 structured bindings for better readability.
            const auto& device = cache_pair.first;
            const auto& cache = cache_pair.second;

            Clear(device);

            if (!cache.Empty()) {
                utility::LogError("{} leaking memory blocks on {}",
                                  cache.Size(), device.ToString());
            }
        }
    }

    Cacher(const Cacher&) = delete;
    Cacher& operator=(Cacher&) = delete;

    void* Malloc(size_t byte_size,
                 const Device& device,
                 const std::shared_ptr<MemoryManagerDevice>& device_mm) {
        Init(device);

        size_t internal_byte_size = MemoryCache::AlignByteSize(byte_size);

        // Malloc from cache.
        void* ptr = device_caches_.at(device).Malloc(internal_byte_size);
        if (ptr != nullptr) {
            return ptr;
        }

        // Malloc from real memory manager.
        try {
            ptr = device_mm->Malloc(internal_byte_size, device);
        } catch (const std::runtime_error&) {
        }

        // Free cached memory and try again.
        if (ptr == nullptr) {
            auto old_ptrs =
                    device_caches_.at(device).Release(internal_byte_size);
            for (const auto& old_pair : old_ptrs) {
                // Simulate C++17 structured bindings for better readability.
                const auto& old_ptr = old_pair.first;
                const auto& old_device_mm = old_pair.second;

                old_device_mm->Free(old_ptr, device);
            }

            // Do not catch the error if the allocation still fails.
            ptr = device_mm->Malloc(internal_byte_size, device);
        }

        device_caches_.at(device).Acquire(ptr, internal_byte_size, device_mm);

        return ptr;
    }

    void Free(void* ptr, const Device& device) {
        Init(device);

        device_caches_.at(device).Free(ptr);
    }

    void Clear(const Device& device) {
        Init(device);

        auto old_ptrs = device_caches_.at(device).ReleaseAll();
        for (const auto& old_pair : old_ptrs) {
            // Simulate C++17 structured bindings for better readability.
            const auto& old_ptr = old_pair.first;
            const auto& old_device_mm = old_pair.second;

            old_device_mm->Free(old_ptr, device);
        }
    }

    void Clear() {
        // Collect all devices in a thread-safe manner. This avoids potential
        // issues with newly initialized/inserted elements while iterating over
        // the container.
        std::vector<Device> devices;
        {
            std::lock_guard<std::recursive_mutex> lock(init_mutex_);
            for (const auto& cache_pair : device_caches_) {
                devices.push_back(cache_pair.first);
            }
        }

        for (const auto& device : devices) {
            Clear(device);
        }
    }

private:
    Cacher() = default;

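    /// Lazily creates the cache for the given device. Performs no action if
    /// the cache has already been created.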
    void Init(const Device& device) {
        std::lock_guard<std::recursive_mutex> lock(init_mutex_);

        // Performs no action if already initialized.
        device_caches_.emplace(std::piecewise_construct,
                               std::forward_as_tuple(device),
                               std::forward_as_tuple());
    }

    std::unordered_map<Device, MemoryCache> device_caches_;
    std::recursive_mutex init_mutex_;
};

MemoryManagerCached::MemoryManagerCached(
        const std::shared_ptr<MemoryManagerDevice>& device_mm)
    : device_mm_(device_mm) {
    if (std::dynamic_pointer_cast<MemoryManagerCached>(device_mm_) != nullptr) {
        utility::LogError(
                "An instance of type MemoryManagerCached as the underlying "
                "non-cached manager is forbidden.");
    }
}

void* MemoryManagerCached::Malloc(size_t byte_size, const Device& device) {
    if (byte_size == 0) {
        return nullptr;
    }

    return Cacher::GetInstance().Malloc(byte_size, device, device_mm_);
}

void MemoryManagerCached::Free(void* ptr, const Device& device) {
    if (ptr == nullptr) {
        return;
    }

    Cacher::GetInstance().Free(ptr, device);
}

void MemoryManagerCached::Memcpy(void* dst_ptr,
                                 const Device& dst_device,
                                 const void* src_ptr,
                                 const Device& src_device,
                                 size_t num_bytes) {
    device_mm_->Memcpy(dst_ptr, dst_device, src_ptr, src_device, num_bytes);
}

void MemoryManagerCached::ReleaseCache(const Device& device) {
    Cacher::GetInstance().Clear(device);
}

void MemoryManagerCached::ReleaseCache() { Cacher::GetInstance().Clear(); }

}  // namespace core
}  // namespace cloudViewer