19 class DeviceHashBackend;
27 const Dtype& key_dtype,
29 const Dtype& value_dtype,
38 const Dtype& key_dtype,
40 const std::vector<Dtype>& dtypes_value,
41 const std::vector<SizeVector>& element_shapes_value,
59 std::pair<Tensor, Tensor>
Insert(
const Tensor& input_keys,
60 const Tensor& input_values);
66 std::pair<Tensor, Tensor>
Insert(
68 const std::vector<Tensor>& input_values_soa);
81 std::pair<Tensor, Tensor>
Find(
const Tensor& input_keys);
97 const Tensor& input_values,
98 Tensor& output_buf_indices,
105 const std::vector<Tensor>& input_values_soa,
106 Tensor& output_buf_indices,
113 Tensor& output_buf_indices,
120 Tensor& output_buf_indices,
138 void Save(
const std::string& file_name);
151 int64_t
Size()
const;
185 return device_hashmap_;
189 void Init(int64_t init_capacity,
194 const std::vector<Tensor>& input_values_soa,
195 Tensor& output_buf_indices,
197 bool is_activate_op =
false);
202 const std::vector<Tensor>& input_values_soa)
const;
205 const std::vector<Tensor>& input_values_soa)
const;
213 std::shared_ptr<DeviceHashBackend> device_hashmap_;
218 std::vector<Dtype> dtypes_value_;
219 std::vector<SizeVector> element_shapes_value_;
void PrepareMasksOutput(Tensor &output_masks, int64_t length) const
std::vector< int64_t > BucketSizes() const
Return number of elements per bucket.
void PrepareIndicesOutput(Tensor &output_buf_indices, int64_t length) const
HashMap(int64_t init_capacity, const Dtype &key_dtype, const SizeVector &key_element_shape, const Dtype &value_dtype, const SizeVector &value_element_shapes, const Device &device, const HashBackendType &backend=HashBackendType::Default)
Initialize a hash map given a key and a value dtype and element shape.
std::pair< Tensor, Tensor > Activate(const Tensor &input_keys)
void InsertImpl(const Tensor &input_keys, const std::vector< Tensor > &input_values_soa, Tensor &output_buf_indices, Tensor &output_masks, bool is_activate_op=false)
std::pair< Tensor, Tensor > Find(const Tensor &input_keys)
void CheckKeyCompatibility(const Tensor &input_keys) const
HashMap To(const Device &device, bool copy=false) const
Convert the hash map to another device.
std::vector< Tensor > GetValueTensors() const
void CheckKeyValueLengthCompatibility(const Tensor &input_keys, const std::vector< Tensor > &input_values_soa) const
void Clear()
Clear stored map without reallocating the buffers.
int64_t GetBucketCount() const
Get the number of buckets of the internal hash map.
Tensor Erase(const Tensor &input_keys)
void CheckKeyLength(const Tensor &input_keys) const
static HashMap Load(const std::string &file_name)
void Save(const std::string &file_name)
void Init(int64_t init_capacity, const Device &device, const HashBackendType &backend)
std::shared_ptr< DeviceHashBackend > GetDeviceHashBackend() const
Return the implementation of the device hash backend.
int64_t GetCapacity() const
Get the capacity of the hash map.
std::pair< Tensor, Tensor > Insert(const Tensor &input_keys, const Tensor &input_values)
void Reserve(int64_t capacity)
Reserve the internal hash map with the given capacity by rehashing.
Tensor GetActiveIndices() const
Tensor GetKeyTensor() const
void CheckValueCompatibility(const std::vector< Tensor > &input_values_soa) const
float LoadFactor() const
Return size / bucket_count.
~HashMap()=default
Default destructor.
Device GetDevice() const override
Get the device of the hash map.
int64_t Size() const
Get the size (number of active entries) of the hash map.
Tensor GetValueTensor(size_t index=0) const
HashMap Clone() const
Clone the hash map with buffers.
std::pair< int64_t, std::vector< int64_t > > GetCommonValueSizeDivisor()
__host__ __device__ float length(float2 v)
Generic file read and write utility for python interface.