21 const Dtype& key_dtype,
23 const Dtype& value_dtype,
27 : key_dtype_(key_dtype),
28 key_element_shape_(key_element_shape),
29 dtypes_value_({value_dtype}),
30 element_shapes_value_({value_element_shape}) {
31 Init(init_capacity, device, backend);
35 const Dtype& key_dtype,
37 const std::vector<Dtype>& dtypes_value,
38 const std::vector<SizeVector>& element_shapes_value,
41 : key_dtype_(key_dtype),
42 key_element_shape_(key_element_shape),
43 dtypes_value_(dtypes_value),
44 element_shapes_value_(element_shapes_value) {
45 Init(init_capacity, device, backend);
50 if (capacity <=
count) {
56 std::vector<Tensor> active_values;
64 for (
auto& value_buffer : value_buffers) {
65 active_values.emplace_back(value_buffer.IndexGet({active_indices}));
69 device_hashmap_->Free();
70 device_hashmap_->Allocate(capacity);
71 device_hashmap_->Reserve(capacity);
74 Tensor output_buf_indices, output_masks;
75 InsertImpl(active_keys, active_values, output_buf_indices,
81 const Tensor& input_values) {
82 Tensor output_buf_indices, output_masks;
83 Insert(input_keys, input_values, output_buf_indices, output_masks);
88 const Tensor& input_keys,
const std::vector<Tensor>& input_values_soa) {
89 Tensor output_buf_indices, output_masks;
90 Insert(input_keys, input_values_soa, output_buf_indices, output_masks);
95 Tensor output_buf_indices, output_masks;
96 Activate(input_keys, output_buf_indices, output_masks);
101 Tensor output_buf_indices, output_masks;
102 Find(input_keys, output_buf_indices, output_masks);
108 Erase(input_keys, output_masks);
113 Tensor output_buf_indices;
115 return output_buf_indices;
119 const std::vector<Tensor>& input_values_soa,
120 Tensor& output_buf_indices,
122 bool is_activate_op) {
124 if (!is_activate_op) {
133 std::vector<const void*> input_values_ptrs;
134 for (
const auto& input_value : input_values_soa) {
135 input_values_ptrs.push_back(input_value.GetDataPtr());
138 device_hashmap_->Insert(
145 const Tensor& input_values,
146 Tensor& output_buf_indices,
148 Insert(input_keys, std::vector<Tensor>{input_values}, output_buf_indices,
153 const std::vector<Tensor>& input_values_soa,
154 Tensor& output_buf_indices,
160 if (new_size > capacity) {
163 InsertImpl(input_keys, input_values_soa, output_buf_indices, output_masks);
167 Tensor& output_buf_indices,
173 if (new_size > capacity) {
177 std::vector<Tensor> null_tensors_soa;
178 InsertImpl(input_keys, null_tensors_soa, output_buf_indices, output_masks,
183 Tensor& output_buf_indices,
192 device_hashmap_->Find(
205 device_hashmap_->Erase(input_keys.
GetDataPtr(),
210 int64_t
length = device_hashmap_->Size();
213 device_hashmap_->GetActiveIndices(
237 Tensor active_buf_indices_i32;
242 std::vector<Tensor> soa_active_values;
243 for (
const auto& value : values) {
244 soa_active_values.push_back(
245 value.IndexGet({active_indices}).To(device));
249 dtypes_value_, element_shapes_value_, device);
250 Tensor buf_indices, masks;
251 new_hashmap.
Insert(active_keys, soa_active_values, buf_indices, masks);
261 return device_hashmap_->GetBucketCount();
271 device_hashmap_->GetKeyBuffer().GetDataPtr(), key_dtype_,
272 device_hashmap_->GetKeyBuffer().GetBlob());
278 std::vector<Tensor> value_buffers = device_hashmap_->GetValueBuffers();
280 std::vector<Tensor> soa_value_tensor;
281 for (
size_t i = 0; i < element_shapes_value_.size(); ++i) {
282 SizeVector value_shape = element_shapes_value_[i];
285 Dtype value_dtype = dtypes_value_[i];
286 soa_value_tensor.push_back(
288 value_buffers[i].GetDataPtr(), value_dtype,
289 value_buffers[i].GetBlob()));
291 return soa_value_tensor;
297 if (i >= dtypes_value_.size()) {
299 dtypes_value_.size());
302 Tensor value_buffer = device_hashmap_->GetValueBuffer(i);
304 SizeVector value_shape = element_shapes_value_[i];
307 Dtype value_dtype = dtypes_value_[i];
314 return device_hashmap_->BucketSizes();
328 "Key element shape must contain at least 1 element, "
333 if (dtypes_value_.size() != element_shapes_value_.size()) {
335 "Size of value_dtype ({}) mismatches with size of "
336 "element_shapes_value ({}).",
337 dtypes_value_.size(), element_shapes_value_.size());
339 for (
const auto& value_dtype : dtypes_value_) {
344 for (
const auto& value_element_shape : element_shapes_value_) {
345 if (value_element_shape.NumElements() == 0) {
347 "Value element shape must contain at least 1 "
348 "element, but got 0.");
353 init_capacity, key_dtype_, key_element_shape_, dtypes_value_,
354 element_shapes_value_, device, backend);
358 int64_t key_len = input_keys.
GetLength();
366 const std::vector<Tensor>& input_values_soa)
const {
367 int64_t key_len = input_keys.
GetLength();
371 for (
size_t i = 0; i < input_values_soa.size(); ++i) {
372 Tensor input_value = input_values_soa[i];
373 if (input_value.
GetLength() != key_len) {
375 "Input number of values at {} mismatch with number of "
385 input_key_elem_shape.
erase(input_key_elem_shape.
begin());
387 int64_t input_key_elem_bytesize = input_key_elem_shape.
NumElements() *
389 int64_t stored_key_elem_bytesize =
391 if (input_key_elem_bytesize != stored_key_elem_bytesize) {
393 "Input key element bytesize ({}) mismatch with stored ({})",
394 input_key_elem_bytesize, stored_key_elem_bytesize);
399 const std::vector<Tensor>& input_values_soa)
const {
400 if (input_values_soa.size() != element_shapes_value_.size()) {
402 "Input number of value arrays ({}) mismatches with stored "
404 input_values_soa.size(), element_shapes_value_.size());
407 for (
size_t i = 0; i < input_values_soa.size(); ++i) {
408 Tensor input_value = input_values_soa[i];
410 input_value_i_elem_shape.
erase(input_value_i_elem_shape.
begin());
412 int64_t input_value_i_elem_bytesize =
416 int64_t stored_value_i_elem_bytesize =
417 element_shapes_value_[i].NumElements() *
418 dtypes_value_[i].ByteSize();
419 if (input_value_i_elem_bytesize != stored_value_i_elem_bytesize) {
421 "Input value[{}] element bytesize ({}) mismatch with "
423 i, input_value_i_elem_bytesize,
424 stored_value_i_elem_bytesize);
DtypeCode GetDtypeCode() const
void PrepareMasksOutput(Tensor &output_masks, int64_t length) const
std::vector< int64_t > BucketSizes() const
Return number of elements per bucket.
void PrepareIndicesOutput(Tensor &output_buf_indices, int64_t length) const
HashMap(int64_t init_capacity, const Dtype &key_dtype, const SizeVector &key_element_shape, const Dtype &value_dtype, const SizeVector &value_element_shapes, const Device &device, const HashBackendType &backend=HashBackendType::Default)
Initialize a hash map given a key and a value dtype and element shape.
std::pair< Tensor, Tensor > Activate(const Tensor &input_keys)
void InsertImpl(const Tensor &input_keys, const std::vector< Tensor > &input_values_soa, Tensor &output_buf_indices, Tensor &output_masks, bool is_activate_op=false)
std::pair< Tensor, Tensor > Find(const Tensor &input_keys)
void CheckKeyCompatibility(const Tensor &input_keys) const
HashMap To(const Device &device, bool copy=false) const
Convert the hash map to another device.
std::vector< Tensor > GetValueTensors() const
void CheckKeyValueLengthCompatibility(const Tensor &input_keys, const std::vector< Tensor > &input_values_soa) const
void Clear()
Clear stored map without reallocating the buffers.
int64_t GetBucketCount() const
Get the number of buckets of the internal hash map.
Tensor Erase(const Tensor &input_keys)
void CheckKeyLength(const Tensor &input_keys) const
static HashMap Load(const std::string &file_name)
void Save(const std::string &file_name)
void Init(int64_t init_capacity, const Device &device, const HashBackendType &backend)
int64_t GetCapacity() const
Get the capacity of the hash map.
std::pair< Tensor, Tensor > Insert(const Tensor &input_keys, const Tensor &input_values)
void Reserve(int64_t capacity)
Reserve the internal hash map with the given capacity by rehashing.
Tensor GetActiveIndices() const
Tensor GetKeyTensor() const
void CheckValueCompatibility(const std::vector< Tensor > &input_values_soa) const
float LoadFactor() const
Return size / bucket_count.
Device GetDevice() const override
Get the device of the hash map.
int64_t Size() const
Get the size (number of active entries) of the hash map.
Tensor GetValueTensor(size_t index=0) const
HashMap Clone() const
Clone the hash map with buffers.
int64_t NumElements() const
iterator erase(const_iterator CI)
iterator insert(iterator I, T &&Elt)
int64_t GetLength() const
Tensor IndexGet(const std::vector< Tensor > &index_tensors) const
Advanced indexing getter. This will always allocate a new Tensor.
Device GetDevice() const override
SizeVector GetShape() const
Tensor To(Dtype dtype, bool copy=false) const
std::shared_ptr< Blob > GetBlob() const
__host__ __device__ float length(float2 v)
Helper functions for the ml ops.
SizeVector DefaultStrides(const SizeVector &shape)
Compute default strides for a shape when a tensor is contiguous.
std::shared_ptr< DeviceHashBackend > CreateDeviceHashBackend(int64_t init_capacity, const Dtype &key_dtype, const SizeVector &key_element_shape, const std::vector< Dtype > &value_dtypes, const std::vector< SizeVector > &value_element_shapes, const Device &device, const HashBackendType &backend)
CLOUDVIEWER_HOST_DEVICE Pair< First, Second > make_pair(const First &_first, const Second &_second)
core::HashMap ReadHashMap(const std::string &file_name)
void WriteHashMap(const std::string &file_name, const core::HashMap &hashmap)
Generic file read and write utility for python interface.