31 #ifndef FLANN_KMEANS_INDEX_H_
32 #define FLANN_KMEANS_INDEX_H_
65 (*this)[
"branching"] = branching;
67 (*this)[
"iterations"] = iterations;
69 (*this)[
"centers_init"] = centers_init;
71 (*this)[
"cb_index"] = cb_index;
82 template <
typename Distance>
147 branching_(other.branching_),
148 iterations_(other.iterations_),
149 centers_init_(other.centers_init_),
150 cb_index_(other.cb_index_),
151 memoryCounter_(other.memoryCounter_)
155 copyTree(root_, other.root_);
167 switch(centers_init_) {
178 throw FLANNException(
"Unknown algorithm for choosing initial centers.");
189 delete chooseCenters_;
218 size_t old_size =
size_;
226 for (
size_t i=0;i<
points.rows;++i) {
228 addPointToTree(root_, old_size + i,
dist);
233 template<
typename Archive>
246 if (Archive::is_loading::value) {
247 root_ =
new(pool_) Node();
251 if (Archive::is_loading::value) {
286 findNeighborsWithRemoved<true>(
result, vec, searchParams);
289 findNeighborsWithRemoved<false>(
result, vec, searchParams);
303 int numClusters = centers.
rows;
309 std::vector<NodePtr> clusters(numClusters);
311 int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
313 Logger::info(
"Clusters requested: %d, returning %d\n",numClusters, clusterCount);
315 for (
int i=0; i<clusterCount; ++i) {
317 for (
size_t j=0; j<
veclen_; ++j) {
318 centers[i][j] = center[j];
331 chooseCenters_->setDataSize(
veclen_);
337 std::vector<int> indices(
size_);
338 for (
size_t i=0; i<
size_; ++i) {
342 root_ =
new(pool_) Node();
343 computeNodeStatistics(root_, indices);
344 computeClustering(root_, &indices[0], (
int)
size_, branching_);
354 template<
typename Archive>
358 Index* obj =
static_cast<Index*
>(ar.getObject());
363 if (Archive::is_loading::value)
point = obj->points_[index];
365 friend struct serialization::access;
392 std::vector<Node*> childs;
396 std::vector<PointInfo>
points;
405 if (!childs.empty()) {
406 for (
size_t i=0; i<childs.size(); ++i) {
412 template<
typename Archive>
415 typedef KMeansIndex<Distance>
Index;
416 Index* obj =
static_cast<Index*
>(ar.getObject());
418 if (Archive::is_loading::value) {
428 if (Archive::is_saving::value) {
429 childs_size = childs.size();
433 if (childs_size==0) {
437 if (Archive::is_loading::value) {
438 childs.resize(childs_size);
440 for (
size_t i=0;i<childs_size;++i) {
441 if (Archive::is_loading::value) {
442 childs[i] =
new(obj->pool_)
Node();
448 friend struct serialization::access;
450 typedef Node* NodePtr;
455 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
463 if (root_) root_->~Node();
468 void copyTree(NodePtr& dst,
const NodePtr& src)
470 dst =
new(pool_)
Node();
473 dst->radius = src->radius;
474 dst->variance = src->variance;
475 dst->size = src->size;
477 if (src->childs.size()==0) {
478 dst->points = src->points;
481 dst->childs.resize(src->childs.size());
482 for (
size_t i=0;i<src->childs.size();++i) {
483 copyTree(dst->childs[i], src->childs[i]);
496 void computeNodeStatistics(NodePtr node,
const std::vector<int>& indices)
498 size_t size = indices.size();
504 for (
size_t i=0; i<
size; ++i) {
506 for (
size_t j=0; j<
veclen_; ++j) {
511 for (
size_t j=0; j<
veclen_; ++j) {
512 mean[j] *= div_factor;
517 for (
size_t i=0; i<
size; ++i) {
526 node->variance = variance;
527 node->radius = radius;
528 delete[] node->pivot;
544 void computeClustering(NodePtr node,
int* indices,
int indices_length,
int branching)
546 node->size = indices_length;
548 if (indices_length < branching) {
549 node->points.resize(indices_length);
550 for (
int i=0;i<indices_length;++i) {
551 node->points[i].index = indices[i];
552 node->points[i].point =
points_[indices[i]];
554 node->childs.clear();
558 std::vector<int> centers_idx(branching);
560 (*chooseCenters_)(branching, indices, indices_length, ¢ers_idx[0], centers_length);
562 if (centers_length<branching) {
563 node->points.resize(indices_length);
564 for (
int i=0;i<indices_length;++i) {
565 node->points[i].index = indices[i];
566 node->points[i].point =
points_[indices[i]];
568 node->childs.clear();
573 Matrix<double> dcenters(
new double[branching*
veclen_],branching,
veclen_);
574 for (
int i=0; i<centers_length; ++i) {
576 for (
size_t k=0; k<
veclen_; ++k) {
577 dcenters[i][k] = double(vec[k]);
581 std::vector<DistanceType> radiuses(branching,0);
582 std::vector<int>
count(branching,0);
585 std::vector<int> belongs_to(indices_length);
586 for (
int i=0; i<indices_length; ++i) {
590 for (
int j=1; j<branching; ++j) {
592 if (sq_dist>new_sq_dist) {
594 sq_dist = new_sq_dist;
597 if (sq_dist>radiuses[belongs_to[i]]) {
598 radiuses[belongs_to[i]] = sq_dist;
600 count[belongs_to[i]]++;
603 bool converged =
false;
605 while (!converged && iteration<iterations_) {
610 for (
int i=0; i<branching; ++i) {
611 memset(dcenters[i],0,
sizeof(
double)*
veclen_);
614 for (
int i=0; i<indices_length; ++i) {
616 double* center = dcenters[belongs_to[i]];
617 for (
size_t k=0; k<
veclen_; ++k) {
621 for (
int i=0; i<branching; ++i) {
623 double div_factor = 1.0/cnt;
624 for (
size_t k=0; k<
veclen_; ++k) {
625 dcenters[i][k] *= div_factor;
630 for (
int i=0; i<indices_length; ++i) {
632 int new_centroid = 0;
633 for (
int j=1; j<branching; ++j) {
635 if (sq_dist>new_sq_dist) {
637 sq_dist = new_sq_dist;
640 if (sq_dist>radiuses[new_centroid]) {
641 radiuses[new_centroid] = sq_dist;
643 if (new_centroid != belongs_to[i]) {
644 count[belongs_to[i]]--;
645 count[new_centroid]++;
646 belongs_to[i] = new_centroid;
652 for (
int i=0; i<branching; ++i) {
656 int j = (i+1)%branching;
657 while (
count[j]<=1) {
661 for (
int k=0; k<indices_length; ++k) {
662 if (belongs_to[k]==j) {
675 std::vector<DistanceType*> centers(branching);
677 for (
int i=0; i<branching; ++i) {
680 for (
size_t k=0; k<
veclen_; ++k) {
687 node->childs.resize(branching);
690 for (
int c=0; c<branching; ++c) {
694 for (
int i=0; i<indices_length; ++i) {
695 if (belongs_to[i]==c) {
698 std::swap(belongs_to[i],belongs_to[end]);
704 node->childs[c] =
new(pool_)
Node();
705 node->childs[c]->radius = radiuses[c];
706 node->childs[c]->pivot = centers[c];
707 node->childs[c]->variance = variance;
708 computeClustering(node->childs[c],indices+start, end-start, branching);
712 delete[] dcenters.ptr();
716 template<
bool with_removed>
717 void findNeighborsWithRemoved(ResultSet<DistanceType>&
result,
const ElementType* vec,
const SearchParams& searchParams)
const
720 int maxChecks = searchParams.checks;
723 findExactNN<with_removed>(root_,
result, vec);
727 Heap<BranchSt>* heap =
new Heap<BranchSt>((
int)
size_);
730 findNN<with_removed>(root_,
result, vec, checks, maxChecks, heap);
733 while (heap->popMin(branch) && (checks<maxChecks || !
result.full())) {
734 NodePtr node = branch.node;
735 findNN<with_removed>(node,
result, vec, checks, maxChecks, heap);
756 template<
bool with_removed>
757 void findNN(NodePtr node, ResultSet<DistanceType>&
result,
const ElementType* vec,
int& checks,
int maxChecks,
758 Heap<BranchSt>* heap)
const
770 if ((val>0)&&(val2>0)) {
775 if (node->childs.empty()) {
776 if (checks>=maxChecks) {
777 if (
result.full())
return;
779 for (
int i=0; i<node->size; ++i) {
780 PointInfo& point_info = node->points[i];
781 int index = point_info.index;
791 int closest_center = exploreNodeBranches(node, vec, heap);
792 findNN<with_removed>(node->childs[closest_center],
result,vec, checks, maxChecks, heap);
804 int exploreNodeBranches(NodePtr node,
const ElementType* q, Heap<BranchSt>* heap)
const
806 std::vector<DistanceType> domain_distances(branching_);
808 domain_distances[best_index] =
distance_(q, node->childs[best_index]->pivot,
veclen_);
809 for (
int i=1; i<branching_; ++i) {
811 if (domain_distances[i]<domain_distances[best_index]) {
817 for (
int i=0; i<branching_; ++i) {
818 if (i != best_index) {
819 domain_distances[i] -= cb_index_*node->childs[i]->variance;
825 heap->insert(BranchSt(node->childs[i],domain_distances[i]));
836 template<
bool with_removed>
837 void findExactNN(NodePtr node, ResultSet<DistanceType>&
result,
const ElementType* vec)
const
849 if ((val>0)&&(val2>0)) {
854 if (node->childs.empty()) {
855 for (
int i=0; i<node->size; ++i) {
856 PointInfo& point_info = node->points[i];
857 int index = point_info.index;
866 std::vector<int> sort_indices(branching_);
867 getCenterOrdering(node, vec, sort_indices);
869 for (
int i=0; i<branching_; ++i) {
870 findExactNN<with_removed>(node->childs[sort_indices[i]],
result,vec);
882 void getCenterOrdering(NodePtr node,
const ElementType* q, std::vector<int>& sort_indices)
const
884 std::vector<DistanceType> domain_distances(branching_);
885 for (
int i=0; i<branching_; ++i) {
889 while (domain_distances[j]<
dist && j<i) j++;
890 for (
int k=i; k>j; --k) {
891 domain_distances[k] = domain_distances[k-1];
892 sort_indices[k] = sort_indices[k-1];
894 domain_distances[j] =
dist;
909 for (
int i=0; i<
veclen_; ++i) {
911 sum += t*(q[i]-(c[i]+p[i])/2);
928 int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters,
int clusters_length,
DistanceType& varianceValue)
const
930 int clusterCount = 1;
935 while (clusterCount<clusters_length) {
939 for (
int i=0; i<clusterCount; ++i) {
940 if (!clusters[i]->childs.empty()) {
942 DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
944 for (
int j=0; j<branching_; ++j) {
945 variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
947 if (variance<minVariance) {
948 minVariance = variance;
954 if (splitIndex==-1)
break;
955 if ( (branching_+clusterCount-1) > clusters_length)
break;
957 meanVariance = minVariance;
960 NodePtr toSplit = clusters[splitIndex];
961 clusters[splitIndex] = toSplit->childs[0];
962 for (
int i=1; i<branching_; ++i) {
963 clusters[clusterCount++] = toSplit->childs[i];
967 varianceValue = meanVariance/root->size;
971 void addPointToTree(NodePtr node,
size_t index,
DistanceType dist_to_pivot)
974 if (dist_to_pivot>node->radius) {
975 node->radius = dist_to_pivot;
978 node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
981 if (node->childs.empty()) {
982 PointInfo point_info;
983 point_info.index = index;
984 point_info.point =
point;
985 node->points.push_back(point_info);
987 std::vector<int> indices(node->points.size());
988 for (
size_t i=0;i<node->points.size();++i) {
989 indices[i] = node->points[i].index;
991 computeNodeStatistics(node, indices);
992 if (indices.size()>=
size_t(branching_)) {
993 computeClustering(node, &indices[0], indices.size(), branching_);
1000 for (
size_t i=1;i<size_t(branching_);++i) {
1002 if (crt_dist<
dist) {
1007 addPointToTree(node->childs[closest], index,
dist);
1014 std::swap(branching_, other.branching_);
1015 std::swap(iterations_, other.iterations_);
1016 std::swap(centers_init_, other.centers_init_);
1020 std::swap(memoryCounter_, other.memoryCounter_);
1021 std::swap(chooseCenters_, other.chooseCenters_);
1051 PooledAllocator pool_;
1061 CenterChooser<Distance>* chooseCenters_;
double Distance(const Point3D< Real > &p1, const Point3D< Real > &p2)
cmdLineReadable * params[]
Generic handle to any of the 8 types of E57 element objects.
bool test(size_t index) const
int getClusterCenters(Matrix< DistanceType > ¢ers)
void addPoints(const Matrix< ElementType > &points, float rebuild_threshold=2)
Incrementally add points to the index.
KMeansIndex & operator=(KMeansIndex other)
NNIndex< Distance > BaseClass
void serialize(Archive &ar)
Distance::ResultType DistanceType
BaseClass * clone() const
KMeansIndex(const Matrix< ElementType > &inputData, const IndexParams ¶ms=KMeansIndexParams(), Distance d=Distance())
KMeansIndex(const KMeansIndex &other)
void saveIndex(FILE *stream)
KMeansIndex(const IndexParams ¶ms=KMeansIndexParams(), Distance d=Distance())
void loadIndex(FILE *stream)
flann_algorithm_t getType() const
virtual void buildIndex()
Distance::ElementType ElementType
void set_cb_index(float index)
bool needs_vector_space_distance
void findNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
static int info(const char *fmt,...)
std::vector< ElementType * > points_
DynamicBitset removed_points_
void setDataset(const Matrix< ElementType > &dataset)
virtual void buildIndex()
void extendDataset(const Matrix< ElementType > &new_points)
IndexParams index_params_
static double dist(double x1, double y1, double x2, double y2)
void swap(optional< T > &x, optional< T > &y) noexcept(noexcept(x.swap(y)))
const binary_object make_binary_object(void *t, size_t size)
T get_param(const IndexParams ¶ms, std::string name, const T &default_value)
std::map< std::string, any > IndexParams
void swap(cloudViewer::core::SmallVectorImpl< T > &LHS, cloudViewer::core::SmallVectorImpl< T > &RHS)
Implement std::swap in terms of SmallVector swap.
#define USING_BASECLASS_SYMBOLS
KMeansIndexParams(int branching=32, int iterations=11, flann_centers_init_t centers_init=FLANN_CENTERS_RANDOM, float cb_index=0.2)