16 #include <unordered_map>
21 #ifdef BUILD_CUDA_MODULE
31 template <
typename Block>
34 const std::shared_ptr<Block>& rhs)
const {
35 if (lhs->byte_size_ != rhs->byte_size_) {
36 return lhs->byte_size_ < rhs->byte_size_;
38 return lhs->ptr_ < rhs->ptr_;
42 template <
typename Block>
45 const std::shared_ptr<Block>& rhs)
const {
46 if (lhs->ptr_ != rhs->ptr_) {
47 return lhs->ptr_ < rhs->ptr_;
49 return lhs->byte_size_ < rhs->byte_size_;
58 const std::weak_ptr<RealBlock>& r_block)
87 return ((byte_size + alignment - 1) / alignment) * alignment;
93 std::lock_guard<std::recursive_mutex> lock(mutex_);
95 auto free_block = ExtractFreeBlock(byte_size);
97 if (free_block !=
nullptr) {
98 size_t remaining_size = free_block->byte_size_ - byte_size;
100 if (remaining_size == 0) {
102 allocated_virtual_blocks_.emplace(free_block->ptr_, free_block);
104 return free_block->ptr_;
107 auto new_block = std::make_shared<VirtualBlock>(
108 free_block->ptr_, byte_size, free_block->r_block_);
109 auto remaining_block = std::make_shared<VirtualBlock>(
110 static_cast<char*
>(free_block->ptr_) + byte_size,
111 remaining_size, free_block->r_block_);
114 auto real_block = free_block->r_block_.lock();
115 real_block->v_blocks_.erase(free_block);
116 real_block->v_blocks_.insert(new_block);
117 real_block->v_blocks_.insert(remaining_block);
119 allocated_virtual_blocks_.emplace(new_block->ptr_, new_block);
120 free_virtual_blocks_.insert(remaining_block);
122 return new_block->ptr_;
133 std::lock_guard<std::recursive_mutex> lock(mutex_);
135 auto ptr_it = allocated_virtual_blocks_.find(ptr);
137 if (ptr_it == allocated_virtual_blocks_.end()) {
143 auto v_block = ptr_it->second;
144 allocated_virtual_blocks_.erase(ptr_it);
146 auto r_block = v_block->r_block_.lock();
147 auto& v_block_set = r_block->v_blocks_;
149 const auto v_block_it = v_block_set.find(v_block);
150 if (v_block_it == v_block_set.end()) {
152 "Virtual block ({} @ {} bytes) not recorded in real block "
154 fmt::ptr(v_block->ptr_), v_block->byte_size_,
155 fmt::ptr(r_block->ptr_), r_block->byte_size_);
158 auto merged_v_block = v_block;
161 if (v_block_it != v_block_set.begin()) {
163 auto v_block_it_copy = v_block_it;
164 auto v_block_it_prev = --v_block_it_copy;
166 auto v_block_prev = *v_block_it_prev;
168 if (free_virtual_blocks_.find(v_block_prev) !=
169 free_virtual_blocks_.end()) {
171 merged_v_block = std::make_shared<VirtualBlock>(
173 v_block_prev->byte_size_ + merged_v_block->byte_size_,
177 v_block_set.erase(v_block_prev);
178 free_virtual_blocks_.erase(v_block_prev);
185 auto v_block_it_copy = v_block_it;
186 auto v_block_it_next = ++v_block_it_copy;
188 if (v_block_it_next != v_block_set.end()) {
189 auto v_block_next = *v_block_it_next;
191 if (free_virtual_blocks_.find(v_block_next) !=
192 free_virtual_blocks_.end()) {
194 merged_v_block = std::make_shared<VirtualBlock>(
195 merged_v_block->ptr_,
196 merged_v_block->byte_size_ + v_block_next->byte_size_,
200 v_block_set.erase(v_block_next);
201 free_virtual_blocks_.erase(v_block_next);
205 v_block_set.erase(v_block);
206 v_block_set.insert(merged_v_block);
207 free_virtual_blocks_.insert(merged_v_block);
213 const std::shared_ptr<MemoryManagerDevice>& device_mm) {
214 std::lock_guard<std::recursive_mutex> lock(mutex_);
216 auto r_block = std::make_shared<RealBlock>(ptr, byte_size);
217 auto v_block = std::make_shared<VirtualBlock>(ptr, byte_size, r_block);
218 r_block->device_mm_ = device_mm;
219 r_block->v_blocks_.insert(v_block);
221 real_blocks_.insert(r_block);
222 allocated_virtual_blocks_.emplace(v_block->ptr_, v_block);
230 std::vector<std::pair<void*, std::shared_ptr<MemoryManagerDevice>>>
Release(
232 std::lock_guard<std::recursive_mutex> lock(mutex_);
236 releasable_real_blocks;
238 real_blocks_.begin(), real_blocks_.end(),
239 std::inserter(releasable_real_blocks,
240 releasable_real_blocks.begin()),
241 [
this](
const auto& r_block) { return IsReleasable(r_block); });
244 std::vector<std::pair<void*, std::shared_ptr<MemoryManagerDevice>>>
246 size_t released_size = 0;
247 while (!releasable_real_blocks.empty() && released_size < byte_size) {
248 size_t remaining_size = byte_size - released_size;
250 std::make_shared<RealBlock>(
nullptr, remaining_size);
251 auto it = releasable_real_blocks.lower_bound(query_size);
252 if (it == releasable_real_blocks.end()) {
257 real_blocks_.erase(r_block);
258 for (
const auto& v_block : r_block->v_blocks_) {
259 free_virtual_blocks_.erase(v_block);
262 releasable_real_blocks.erase(r_block);
263 released_pointers.emplace_back(r_block->ptr_, r_block->device_mm_);
264 released_size += r_block->byte_size_;
267 return released_pointers;
271 std::vector<std::pair<void*, std::shared_ptr<MemoryManagerDevice>>>
277 size_t Size()
const {
return real_blocks_.size(); }
287 std::shared_ptr<VirtualBlock> ExtractFreeBlock(
size_t byte_size) {
288 std::lock_guard<std::recursive_mutex> lock(mutex_);
290 size_t max_byte_size =
static_cast<size_t>(
291 kMaxFragmentation *
static_cast<double>(byte_size));
295 auto query_size = std::make_shared<VirtualBlock>(
296 nullptr, byte_size, std::weak_ptr<RealBlock>());
297 auto it = free_virtual_blocks_.lower_bound(query_size);
298 while (it != free_virtual_blocks_.end() &&
299 (*it)->byte_size_ <= max_byte_size) {
300 auto r_block = (*it)->r_block_.lock();
301 if (r_block->byte_size_ <= max_byte_size) {
303 free_virtual_blocks_.erase(it);
313 bool IsReleasable(
const std::shared_ptr<RealBlock>& r_block) {
314 if (r_block->v_blocks_.size() != 1) {
318 auto v_block = *(r_block->v_blocks_.begin());
319 if (r_block->ptr_ != v_block->ptr_ ||
320 r_block->byte_size_ != v_block->byte_size_) {
322 "Real block {} @ {} bytes has single "
323 "virtual block {} @ {} bytes",
324 fmt::ptr(r_block->ptr_), r_block->byte_size_,
325 fmt::ptr(v_block->ptr_), v_block->byte_size_);
328 return free_virtual_blocks_.find(v_block) != free_virtual_blocks_.end();
332 const double kMaxFragmentation = 4.0;
334 std::set<std::shared_ptr<RealBlock>, SizeOrder<RealBlock>> real_blocks_;
336 std::unordered_map<void*, std::shared_ptr<VirtualBlock>>
337 allocated_virtual_blocks_;
338 std::set<std::shared_ptr<VirtualBlock>, SizeOrder<VirtualBlock>>
339 free_virtual_blocks_;
341 std::recursive_mutex mutex_;
353 #ifdef BUILD_CUDA_MODULE
363 for (
const auto& cache_pair : device_caches_) {
365 const auto& device = cache_pair.first;
366 const auto& cache = cache_pair.second;
370 if (!cache.Empty()) {
372 cache.Size(), device.ToString());
382 const std::shared_ptr<MemoryManagerDevice>& device_mm) {
388 void* ptr = device_caches_.at(device).Malloc(internal_byte_size);
389 if (ptr !=
nullptr) {
395 ptr = device_mm->Malloc(internal_byte_size, device);
396 }
catch (
const std::runtime_error&) {
400 if (ptr ==
nullptr) {
402 device_caches_.at(device).Release(internal_byte_size);
403 for (
const auto& old_pair : old_ptrs) {
405 const auto& old_ptr = old_pair.first;
406 const auto& old_device_mm = old_pair.second;
408 old_device_mm->Free(old_ptr, device);
412 ptr = device_mm->Malloc(internal_byte_size, device);
415 device_caches_.at(device).Acquire(ptr, internal_byte_size, device_mm);
423 device_caches_.at(device).Free(ptr);
429 auto old_ptrs = device_caches_.at(device).ReleaseAll();
430 for (
const auto& old_pair : old_ptrs) {
432 const auto& old_ptr = old_pair.first;
433 const auto& old_device_mm = old_pair.second;
435 old_device_mm->Free(old_ptr, device);
443 std::vector<Device> devices;
445 std::lock_guard<std::recursive_mutex> lock(init_mutex_);
446 for (
const auto& cache_pair : device_caches_) {
447 devices.push_back(cache_pair.first);
451 for (
const auto& device : devices) {
461 void Init(
const Device& device) {
462 std::lock_guard<std::recursive_mutex> lock(init_mutex_);
465 device_caches_.emplace(std::piecewise_construct,
466 std::forward_as_tuple(device),
467 std::forward_as_tuple());
470 std::unordered_map<Device, MemoryCache> device_caches_;
471 std::recursive_mutex init_mutex_;
475 const std::shared_ptr<MemoryManagerDevice>& device_mm)
476 : device_mm_(device_mm) {
477 if (std::dynamic_pointer_cast<MemoryManagerCached>(
device_mm_) !=
nullptr) {
479 "An instance of type MemoryManagerCached as the underlying "
480 "non-cached manager is forbidden.");
485 if (byte_size == 0) {
493 if (ptr ==
nullptr) {
505 device_mm_->Memcpy(dst_ptr, dst_device, src_ptr, src_device, num_bytes);
void Free(void *ptr, const Device &device)
void Clear(const Device &device)
static Cacher & GetInstance()
Cacher(const Cacher &)=delete
Cacher & operator=(Cacher &)=delete
void * Malloc(size_t byte_size, const Device &device, const std::shared_ptr< MemoryManagerDevice > &device_mm)
MemoryCache & operator=(const MemoryCache &)=delete
MemoryCache(const MemoryCache &)=delete
void Acquire(void *ptr, size_t byte_size, const std::shared_ptr< MemoryManagerDevice > &device_mm)
Acquires ownership of the new real allocated blocks.
static size_t AlignByteSize(size_t byte_size, size_t alignment=8)
void * Malloc(size_t byte_size)
bool Empty() const
True if the set of allocated real blocks is empty, false otherwise.
std::vector< std::pair< void *, std::shared_ptr< MemoryManagerDevice > > > Release(size_t byte_size)
size_t Size() const
Returns the number of allocated real blocks.
std::vector< std::pair< void *, std::shared_ptr< MemoryManagerDevice > > > ReleaseAll()
Releases ownership of all unused real allocated blocks.
void Memcpy(void *dst_ptr, const Device &dst_device, const void *src_ptr, const Device &src_device, size_t num_bytes) override
MemoryManagerCached(const std::shared_ptr< MemoryManagerDevice > &device_mm)
static void ReleaseCache()
std::shared_ptr< MemoryManagerDevice > device_mm_
void * Malloc(size_t byte_size, const Device &device) override
void Free(void *ptr, const Device &device) override
Frees previously allocated memory at address ptr on device device.
static Logger & GetInstance()
Get Logger global singleton instance.
ccGuiPythonInstance * GetInstance() noexcept
Generic file read and write utility for python interface.
bool operator()(const std::shared_ptr< Block > &lhs, const std::shared_ptr< Block > &rhs) const
RealBlock(void *ptr, size_t byte_size)
std::set< std::shared_ptr< VirtualBlock >, PointerOrder< VirtualBlock > > v_blocks_
std::shared_ptr< MemoryManagerDevice > device_mm_
bool operator()(const std::shared_ptr< Block > &lhs, const std::shared_ptr< Block > &rhs) const
VirtualBlock(void *ptr, size_t byte_size, const std::weak_ptr< RealBlock > &r_block)
std::weak_ptr< RealBlock > r_block_