ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
kmeans_index.h
Go to the documentation of this file.
1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  * notice, this list of conditions and the following disclaimer in the
17  * documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30 
31 #ifndef FLANN_KMEANS_INDEX_H_
32 #define FLANN_KMEANS_INDEX_H_
33 
34 #include <algorithm>
35 #include <string>
36 #include <map>
37 #include <cassert>
38 #include <limits>
39 #include <cmath>
40 
41 #include "FLANN/general.h"
43 #include "FLANN/algorithms/dist.h"
45 #include "FLANN/util/matrix.h"
46 #include "FLANN/util/result_set.h"
47 #include "FLANN/util/heap.h"
48 #include "FLANN/util/allocator.h"
49 #include "FLANN/util/random.h"
50 #include "FLANN/util/saving.h"
51 #include "FLANN/util/logger.h"
52 
53 
54 
55 namespace flann
56 {
57 
59 {
60  KMeansIndexParams(int branching = 32, int iterations = 11,
61  flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 )
62  {
63  (*this)["algorithm"] = FLANN_INDEX_KMEANS;
64  // branching factor
65  (*this)["branching"] = branching;
66  // max iterations to perform in one kmeans clustering (kmeans tree)
67  (*this)["iterations"] = iterations;
68  // algorithm used for picking the initial cluster centers for kmeans tree
69  (*this)["centers_init"] = centers_init;
70  // cluster boundary index. Used when searching the kmeans tree
71  (*this)["cb_index"] = cb_index;
72  }
73 };
74 
75 
82 template <typename Distance>
83 class KMeansIndex : public NNIndex<Distance>
84 {
85 public:
86  typedef typename Distance::ElementType ElementType;
87  typedef typename Distance::ResultType DistanceType;
88 
90 
92 
93 
94 
96  {
97  return FLANN_INDEX_KMEANS;
98  }
99 
108  Distance d = Distance())
109  : BaseClass(params,d), root_(NULL), memoryCounter_(0)
110  {
111  branching_ = get_param(params,"branching",32);
112  iterations_ = get_param(params,"iterations",11);
113  if (iterations_<0) {
114  iterations_ = (std::numeric_limits<int>::max)();
115  }
116  centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
117  cb_index_ = get_param(params,"cb_index",0.4f);
118 
120  setDataset(inputData);
121  }
122 
123 
132  : BaseClass(params, d), root_(NULL), memoryCounter_(0)
133  {
134  branching_ = get_param(params,"branching",32);
135  iterations_ = get_param(params,"iterations",11);
136  if (iterations_<0) {
137  iterations_ = (std::numeric_limits<int>::max)();
138  }
139  centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
140  cb_index_ = get_param(params,"cb_index",0.4f);
141 
143  }
144 
145 
146  KMeansIndex(const KMeansIndex& other) : BaseClass(other),
147  branching_(other.branching_),
148  iterations_(other.iterations_),
149  centers_init_(other.centers_init_),
150  cb_index_(other.cb_index_),
151  memoryCounter_(other.memoryCounter_)
152  {
154 
155  copyTree(root_, other.root_);
156  }
157 
159  {
160  this->swap(other);
161  return *this;
162  }
163 
164 
166  {
167  switch(centers_init_) {
169  chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
170  break;
172  chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
173  break;
175  chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
176  break;
177  default:
178  throw FLANNException("Unknown algorithm for choosing initial centers.");
179  }
180  }
181 
187  virtual ~KMeansIndex()
188  {
189  delete chooseCenters_;
190  freeIndex();
191  }
192 
193  BaseClass* clone() const
194  {
195  return new KMeansIndex(*this);
196  }
197 
198 
199  void set_cb_index( float index)
200  {
201  cb_index_ = index;
202  }
203 
208  int usedMemory() const
209  {
210  return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
211  }
212 
213  using BaseClass::buildIndex;
214 
215  void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
216  {
217  assert(points.cols==veclen_);
218  size_t old_size = size_;
219 
221 
222  if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
223  buildIndex();
224  }
225  else {
226  for (size_t i=0;i<points.rows;++i) {
227  DistanceType dist = distance_(root_->pivot, points[i], veclen_);
228  addPointToTree(root_, old_size + i, dist);
229  }
230  }
231  }
232 
233  template<typename Archive>
234  void serialize(Archive& ar)
235  {
236  ar.setObject(this);
237 
238  ar & *static_cast<NNIndex<Distance>*>(this);
239 
240  ar & branching_;
241  ar & iterations_;
242  ar & memoryCounter_;
243  ar & cb_index_;
244  ar & centers_init_;
245 
246  if (Archive::is_loading::value) {
247  root_ = new(pool_) Node();
248  }
249  ar & *root_;
250 
251  if (Archive::is_loading::value) {
252  index_params_["algorithm"] = getType();
253  index_params_["branching"] = branching_;
254  index_params_["iterations"] = iterations_;
255  index_params_["centers_init"] = centers_init_;
256  index_params_["cb_index"] = cb_index_;
257  }
258  }
259 
260  void saveIndex(FILE* stream)
261  {
262  serialization::SaveArchive sa(stream);
263  sa & *this;
264  }
265 
266  void loadIndex(FILE* stream)
267  {
268  freeIndex();
269  serialization::LoadArchive la(stream);
270  la & *this;
271  }
272 
283  void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
284  {
285  if (removed_) {
286  findNeighborsWithRemoved<true>(result, vec, searchParams);
287  }
288  else {
289  findNeighborsWithRemoved<false>(result, vec, searchParams);
290  }
291 
292  }
293 
302  {
303  int numClusters = centers.rows;
304  if (numClusters<1) {
305  throw FLANNException("Number of clusters must be at least 1");
306  }
307 
308  DistanceType variance;
309  std::vector<NodePtr> clusters(numClusters);
310 
311  int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
312 
313  Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);
314 
315  for (int i=0; i<clusterCount; ++i) {
316  DistanceType* center = clusters[i]->pivot;
317  for (size_t j=0; j<veclen_; ++j) {
318  centers[i][j] = center[j];
319  }
320  }
321 
322  return clusterCount;
323  }
324 
325 protected:
330  {
331  chooseCenters_->setDataSize(veclen_);
332 
333  if (branching_<2) {
334  throw FLANNException("Branching factor must be at least 2");
335  }
336 
337  std::vector<int> indices(size_);
338  for (size_t i=0; i<size_; ++i) {
339  indices[i] = int(i);
340  }
341 
342  root_ = new(pool_) Node();
343  computeNodeStatistics(root_, indices);
344  computeClustering(root_, &indices[0], (int)size_, branching_);
345  }
346 
347 private:
348 
349  struct PointInfo
350  {
351  size_t index;
353  private:
354  template<typename Archive>
355  void serialize(Archive& ar)
356  {
358  Index* obj = static_cast<Index*>(ar.getObject());
359 
360  ar & index;
361 // ar & point;
362 
363  if (Archive::is_loading::value) point = obj->points_[index];
364  }
365  friend struct serialization::access;
366  };
367 
371  struct Node
372  {
376  DistanceType* pivot=NULL;
380  DistanceType radius;
384  DistanceType variance;
388  int size;
392  std::vector<Node*> childs;
396  std::vector<PointInfo> points;
400 // int level;
401 
402  ~Node()
403  {
404  delete[] pivot;
405  if (!childs.empty()) {
406  for (size_t i=0; i<childs.size(); ++i) {
407  childs[i]->~Node();
408  }
409  }
410  }
411 
412  template<typename Archive>
413  void serialize(Archive& ar)
414  {
415  typedef KMeansIndex<Distance> Index;
416  Index* obj = static_cast<Index*>(ar.getObject());
417 
418  if (Archive::is_loading::value) {
419  delete[] pivot;
420  pivot = new DistanceType[obj->veclen_];
421  }
422  ar & serialization::make_binary_object(pivot, obj->veclen_*sizeof(DistanceType));
423  ar & radius;
424  ar & variance;
425  ar & size;
426 
427  size_t childs_size;
428  if (Archive::is_saving::value) {
429  childs_size = childs.size();
430  }
431  ar & childs_size;
432 
433  if (childs_size==0) {
434  ar & points;
435  }
436  else {
437  if (Archive::is_loading::value) {
438  childs.resize(childs_size);
439  }
440  for (size_t i=0;i<childs_size;++i) {
441  if (Archive::is_loading::value) {
442  childs[i] = new(obj->pool_) Node();
443  }
444  ar & *childs[i];
445  }
446  }
447  }
448  friend struct serialization::access;
449  };
450  typedef Node* NodePtr;
451 
455  typedef BranchStruct<NodePtr, DistanceType> BranchSt;
456 
457 
461  void freeIndex()
462  {
463  if (root_) root_->~Node();
464  root_ = NULL;
465  pool_.free();
466  }
467 
468  void copyTree(NodePtr& dst, const NodePtr& src)
469  {
470  dst = new(pool_) Node();
471  dst->pivot = new DistanceType[veclen_];
472  std::copy(src->pivot, src->pivot+veclen_, dst->pivot);
473  dst->radius = src->radius;
474  dst->variance = src->variance;
475  dst->size = src->size;
476 
477  if (src->childs.size()==0) {
478  dst->points = src->points;
479  }
480  else {
481  dst->childs.resize(src->childs.size());
482  for (size_t i=0;i<src->childs.size();++i) {
483  copyTree(dst->childs[i], src->childs[i]);
484  }
485  }
486  }
487 
488 
496  void computeNodeStatistics(NodePtr node, const std::vector<int>& indices)
497  {
498  size_t size = indices.size();
499 
500  DistanceType* mean = new DistanceType[veclen_];
501  memoryCounter_ += int(veclen_*sizeof(DistanceType));
502  memset(mean,0,veclen_*sizeof(DistanceType));
503 
504  for (size_t i=0; i<size; ++i) {
505  ElementType* vec = points_[indices[i]];
506  for (size_t j=0; j<veclen_; ++j) {
507  mean[j] += vec[j];
508  }
509  }
510  DistanceType div_factor = DistanceType(1)/size;
511  for (size_t j=0; j<veclen_; ++j) {
512  mean[j] *= div_factor;
513  }
514 
515  DistanceType radius = 0;
516  DistanceType variance = 0;
517  for (size_t i=0; i<size; ++i) {
518  DistanceType dist = distance_(mean, points_[indices[i]], veclen_);
519  if (dist>radius) {
520  radius = dist;
521  }
522  variance += dist;
523  }
524  variance /= size;
525 
526  node->variance = variance;
527  node->radius = radius;
528  delete[] node->pivot;
529  node->pivot = mean;
530  }
531 
532 
544  void computeClustering(NodePtr node, int* indices, int indices_length, int branching)
545  {
546  node->size = indices_length;
547 
548  if (indices_length < branching) {
549  node->points.resize(indices_length);
550  for (int i=0;i<indices_length;++i) {
551  node->points[i].index = indices[i];
552  node->points[i].point = points_[indices[i]];
553  }
554  node->childs.clear();
555  return;
556  }
557 
558  std::vector<int> centers_idx(branching);
559  int centers_length;
560  (*chooseCenters_)(branching, indices, indices_length, &centers_idx[0], centers_length);
561 
562  if (centers_length<branching) {
563  node->points.resize(indices_length);
564  for (int i=0;i<indices_length;++i) {
565  node->points[i].index = indices[i];
566  node->points[i].point = points_[indices[i]];
567  }
568  node->childs.clear();
569  return;
570  }
571 
572 
573  Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
574  for (int i=0; i<centers_length; ++i) {
575  ElementType* vec = points_[centers_idx[i]];
576  for (size_t k=0; k<veclen_; ++k) {
577  dcenters[i][k] = double(vec[k]);
578  }
579  }
580 
581  std::vector<DistanceType> radiuses(branching,0);
582  std::vector<int> count(branching,0);
583 
584  // assign points to clusters
585  std::vector<int> belongs_to(indices_length);
586  for (int i=0; i<indices_length; ++i) {
587 
588  DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
589  belongs_to[i] = 0;
590  for (int j=1; j<branching; ++j) {
591  DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
592  if (sq_dist>new_sq_dist) {
593  belongs_to[i] = j;
594  sq_dist = new_sq_dist;
595  }
596  }
597  if (sq_dist>radiuses[belongs_to[i]]) {
598  radiuses[belongs_to[i]] = sq_dist;
599  }
600  count[belongs_to[i]]++;
601  }
602 
603  bool converged = false;
604  int iteration = 0;
605  while (!converged && iteration<iterations_) {
606  converged = true;
607  iteration++;
608 
609  // compute the new cluster centers
610  for (int i=0; i<branching; ++i) {
611  memset(dcenters[i],0,sizeof(double)*veclen_);
612  radiuses[i] = 0;
613  }
614  for (int i=0; i<indices_length; ++i) {
615  ElementType* vec = points_[indices[i]];
616  double* center = dcenters[belongs_to[i]];
617  for (size_t k=0; k<veclen_; ++k) {
618  center[k] += vec[k];
619  }
620  }
621  for (int i=0; i<branching; ++i) {
622  int cnt = count[i];
623  double div_factor = 1.0/cnt;
624  for (size_t k=0; k<veclen_; ++k) {
625  dcenters[i][k] *= div_factor;
626  }
627  }
628 
629  // reassign points to clusters
630  for (int i=0; i<indices_length; ++i) {
631  DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
632  int new_centroid = 0;
633  for (int j=1; j<branching; ++j) {
634  DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
635  if (sq_dist>new_sq_dist) {
636  new_centroid = j;
637  sq_dist = new_sq_dist;
638  }
639  }
640  if (sq_dist>radiuses[new_centroid]) {
641  radiuses[new_centroid] = sq_dist;
642  }
643  if (new_centroid != belongs_to[i]) {
644  count[belongs_to[i]]--;
645  count[new_centroid]++;
646  belongs_to[i] = new_centroid;
647 
648  converged = false;
649  }
650  }
651 
652  for (int i=0; i<branching; ++i) {
653  // if one cluster converges to an empty cluster,
654  // move an element into that cluster
655  if (count[i]==0) {
656  int j = (i+1)%branching;
657  while (count[j]<=1) {
658  j = (j+1)%branching;
659  }
660 
661  for (int k=0; k<indices_length; ++k) {
662  if (belongs_to[k]==j) {
663  belongs_to[k] = i;
664  count[j]--;
665  count[i]++;
666  break;
667  }
668  }
669  converged = false;
670  }
671  }
672 
673  }
674 
675  std::vector<DistanceType*> centers(branching);
676 
677  for (int i=0; i<branching; ++i) {
678  centers[i] = new DistanceType[veclen_];
679  memoryCounter_ += veclen_*sizeof(DistanceType);
680  for (size_t k=0; k<veclen_; ++k) {
681  centers[i][k] = (DistanceType)dcenters[i][k];
682  }
683  }
684 
685 
686  // compute kmeans clustering for each of the resulting clusters
687  node->childs.resize(branching);
688  int start = 0;
689  int end = start;
690  for (int c=0; c<branching; ++c) {
691  int s = count[c];
692 
693  DistanceType variance = 0;
694  for (int i=0; i<indices_length; ++i) {
695  if (belongs_to[i]==c) {
696  variance += distance_(centers[c], points_[indices[i]], veclen_);
697  std::swap(indices[i],indices[end]);
698  std::swap(belongs_to[i],belongs_to[end]);
699  end++;
700  }
701  }
702  variance /= s;
703 
704  node->childs[c] = new(pool_) Node();
705  node->childs[c]->radius = radiuses[c];
706  node->childs[c]->pivot = centers[c];
707  node->childs[c]->variance = variance;
708  computeClustering(node->childs[c],indices+start, end-start, branching);
709  start=end;
710  }
711 
712  delete[] dcenters.ptr();
713  }
714 
715 
716  template<bool with_removed>
717  void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
718  {
719 
720  int maxChecks = searchParams.checks;
721 
722  if (maxChecks==FLANN_CHECKS_UNLIMITED) {
723  findExactNN<with_removed>(root_, result, vec);
724  }
725  else {
726  // Priority queue storing intermediate branches in the best-bin-first search
727  Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
728 
729  int checks = 0;
730  findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);
731 
732  BranchSt branch;
733  while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
734  NodePtr node = branch.node;
735  findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
736  }
737 
738  delete heap;
739  }
740 
741  }
742 
743 
756  template<bool with_removed>
757  void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
758  Heap<BranchSt>* heap) const
759  {
760  // Ignore those clusters that are too far away
761  {
762  DistanceType bsq = distance_(vec, node->pivot, veclen_);
763  DistanceType rsq = node->radius;
764  DistanceType wsq = result.worstDist();
765 
766  DistanceType val = bsq-rsq-wsq;
767  DistanceType val2 = val*val-4*rsq*wsq;
768 
769  //if (val>0) {
770  if ((val>0)&&(val2>0)) {
771  return;
772  }
773  }
774 
775  if (node->childs.empty()) {
776  if (checks>=maxChecks) {
777  if (result.full()) return;
778  }
779  for (int i=0; i<node->size; ++i) {
780  PointInfo& point_info = node->points[i];
781  int index = point_info.index;
782  if (with_removed) {
783  if (removed_points_.test(index)) continue;
784  }
785  DistanceType dist = distance_(point_info.point, vec, veclen_);
786  result.addPoint(dist, index);
787  ++checks;
788  }
789  }
790  else {
791  int closest_center = exploreNodeBranches(node, vec, heap);
792  findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
793  }
794  }
795 
804  int exploreNodeBranches(NodePtr node, const ElementType* q, Heap<BranchSt>* heap) const
805  {
806  std::vector<DistanceType> domain_distances(branching_);
807  int best_index = 0;
808  domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_);
809  for (int i=1; i<branching_; ++i) {
810  domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_);
811  if (domain_distances[i]<domain_distances[best_index]) {
812  best_index = i;
813  }
814  }
815 
816  // float* best_center = node->childs[best_index]->pivot;
817  for (int i=0; i<branching_; ++i) {
818  if (i != best_index) {
819  domain_distances[i] -= cb_index_*node->childs[i]->variance;
820 
821  // float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q);
822  // if (domain_distances[i]<dist_to_border) {
823  // domain_distances[i] = dist_to_border;
824  // }
825  heap->insert(BranchSt(node->childs[i],domain_distances[i]));
826  }
827  }
828 
829  return best_index;
830  }
831 
832 
836  template<bool with_removed>
837  void findExactNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) const
838  {
839  // Ignore those clusters that are too far away
840  {
841  DistanceType bsq = distance_(vec, node->pivot, veclen_);
842  DistanceType rsq = node->radius;
843  DistanceType wsq = result.worstDist();
844 
845  DistanceType val = bsq-rsq-wsq;
846  DistanceType val2 = val*val-4*rsq*wsq;
847 
848  // if (val>0) {
849  if ((val>0)&&(val2>0)) {
850  return;
851  }
852  }
853 
854  if (node->childs.empty()) {
855  for (int i=0; i<node->size; ++i) {
856  PointInfo& point_info = node->points[i];
857  int index = point_info.index;
858  if (with_removed) {
859  if (removed_points_.test(index)) continue;
860  }
861  DistanceType dist = distance_(point_info.point, vec, veclen_);
862  result.addPoint(dist, index);
863  }
864  }
865  else {
866  std::vector<int> sort_indices(branching_);
867  getCenterOrdering(node, vec, sort_indices);
868 
869  for (int i=0; i<branching_; ++i) {
870  findExactNN<with_removed>(node->childs[sort_indices[i]],result,vec);
871  }
872 
873  }
874  }
875 
876 
882  void getCenterOrdering(NodePtr node, const ElementType* q, std::vector<int>& sort_indices) const
883  {
884  std::vector<DistanceType> domain_distances(branching_);
885  for (int i=0; i<branching_; ++i) {
886  DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_);
887 
888  int j=0;
889  while (domain_distances[j]<dist && j<i) j++;
890  for (int k=i; k>j; --k) {
891  domain_distances[k] = domain_distances[k-1];
892  sort_indices[k] = sort_indices[k-1];
893  }
894  domain_distances[j] = dist;
895  sort_indices[j] = i;
896  }
897  }
898 
904  DistanceType getDistanceToBorder(DistanceType* p, DistanceType* c, DistanceType* q) const
905  {
906  DistanceType sum = 0;
907  DistanceType sum2 = 0;
908 
909  for (int i=0; i<veclen_; ++i) {
910  DistanceType t = c[i]-p[i];
911  sum += t*(q[i]-(c[i]+p[i])/2);
912  sum2 += t*t;
913  }
914 
915  return sum*sum/sum2;
916  }
917 
918 
928  int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters, int clusters_length, DistanceType& varianceValue) const
929  {
930  int clusterCount = 1;
931  clusters[0] = root;
932 
933  DistanceType meanVariance = root->variance*root->size;
934 
935  while (clusterCount<clusters_length) {
937  int splitIndex = -1;
938 
939  for (int i=0; i<clusterCount; ++i) {
940  if (!clusters[i]->childs.empty()) {
941 
942  DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
943 
944  for (int j=0; j<branching_; ++j) {
945  variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
946  }
947  if (variance<minVariance) {
948  minVariance = variance;
949  splitIndex = i;
950  }
951  }
952  }
953 
954  if (splitIndex==-1) break;
955  if ( (branching_+clusterCount-1) > clusters_length) break;
956 
957  meanVariance = minVariance;
958 
959  // split node
960  NodePtr toSplit = clusters[splitIndex];
961  clusters[splitIndex] = toSplit->childs[0];
962  for (int i=1; i<branching_; ++i) {
963  clusters[clusterCount++] = toSplit->childs[i];
964  }
965  }
966 
967  varianceValue = meanVariance/root->size;
968  return clusterCount;
969  }
970 
971  void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
972  {
973  ElementType* point = points_[index];
974  if (dist_to_pivot>node->radius) {
975  node->radius = dist_to_pivot;
976  }
977  // if radius changed above, the variance will be an approximation
978  node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
979  node->size++;
980 
981  if (node->childs.empty()) { // leaf node
982  PointInfo point_info;
983  point_info.index = index;
984  point_info.point = point;
985  node->points.push_back(point_info);
986 
987  std::vector<int> indices(node->points.size());
988  for (size_t i=0;i<node->points.size();++i) {
989  indices[i] = node->points[i].index;
990  }
991  computeNodeStatistics(node, indices);
992  if (indices.size()>=size_t(branching_)) {
993  computeClustering(node, &indices[0], indices.size(), branching_);
994  }
995  }
996  else {
997  // find the closest child
998  int closest = 0;
999  DistanceType dist = distance_(node->childs[closest]->pivot, point, veclen_);
1000  for (size_t i=1;i<size_t(branching_);++i) {
1001  DistanceType crt_dist = distance_(node->childs[i]->pivot, point, veclen_);
1002  if (crt_dist<dist) {
1003  dist = crt_dist;
1004  closest = i;
1005  }
1006  }
1007  addPointToTree(node->childs[closest], index, dist);
1008  }
1009  }
1010 
1011 
1012  void swap(KMeansIndex& other)
1013  {
1014  std::swap(branching_, other.branching_);
1015  std::swap(iterations_, other.iterations_);
1016  std::swap(centers_init_, other.centers_init_);
1017  std::swap(cb_index_, other.cb_index_);
1018  std::swap(root_, other.root_);
1019  std::swap(pool_, other.pool_);
1020  std::swap(memoryCounter_, other.memoryCounter_);
1021  std::swap(chooseCenters_, other.chooseCenters_);
1022  }
1023 
1024 
1025 private:
1027  int branching_;
1028 
1030  int iterations_;
1031 
1033  flann_centers_init_t centers_init_;
1034 
1041  float cb_index_;
1042 
1046  NodePtr root_;
1047 
1051  PooledAllocator pool_;
1052 
1056  int memoryCounter_;
1057 
1061  CenterChooser<Distance>* chooseCenters_;
1062 
1064 };
1065 
1066 }
1067 
1068 #endif //FLANN_KMEANS_INDEX_H_
int count
int points
double Distance(const Point3D< Real > &p1, const Point3D< Real > &p2)
cmdLineReadable * params[]
#define NULL
core::Tensor result
Definition: VtkUtils.cpp:76
bool copy
Definition: VtkUtils.cpp:74
Generic handle to any of the 8 types of E57 element objects.
bool test(size_t index) const
int getClusterCenters(Matrix< DistanceType > &centers)
Definition: kmeans_index.h:301
void addPoints(const Matrix< ElementType > &points, float rebuild_threshold=2)
Incrementally add points to the index.
Definition: kmeans_index.h:215
KMeansIndex & operator=(KMeansIndex other)
Definition: kmeans_index.h:158
NNIndex< Distance > BaseClass
Definition: kmeans_index.h:89
virtual ~KMeansIndex()
Definition: kmeans_index.h:187
void serialize(Archive &ar)
Definition: kmeans_index.h:234
Distance::ResultType DistanceType
Definition: kmeans_index.h:87
BaseClass * clone() const
Definition: kmeans_index.h:193
KMeansIndex(const Matrix< ElementType > &inputData, const IndexParams &params=KMeansIndexParams(), Distance d=Distance())
Definition: kmeans_index.h:107
KMeansIndex(const KMeansIndex &other)
Definition: kmeans_index.h:146
int usedMemory() const
Definition: kmeans_index.h:208
void saveIndex(FILE *stream)
Definition: kmeans_index.h:260
KMeansIndex(const IndexParams &params=KMeansIndexParams(), Distance d=Distance())
Definition: kmeans_index.h:131
void loadIndex(FILE *stream)
Definition: kmeans_index.h:266
flann_algorithm_t getType() const
Definition: kmeans_index.h:95
virtual void buildIndex()
Definition: nn_index.h:125
Distance::ElementType ElementType
Definition: kmeans_index.h:86
void set_cb_index(float index)
Definition: kmeans_index.h:199
bool needs_vector_space_distance
Definition: kmeans_index.h:91
void findNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
Definition: kmeans_index.h:283
static int info(const char *fmt,...)
Definition: logger.h:127
size_t rows
Definition: matrix.h:72
std::vector< ElementType * > points_
Definition: nn_index.h:876
size_t veclen_
Definition: nn_index.h:846
size_t size() const
Definition: nn_index.h:201
DynamicBitset removed_points_
Definition: nn_index.h:861
size_t size_
Definition: nn_index.h:836
void setDataset(const Matrix< ElementType > &dataset)
Definition: nn_index.h:746
Distance distance_
Definition: nn_index.h:823
virtual void buildIndex()
Definition: nn_index.h:125
void extendDataset(const Matrix< ElementType > &new_points)
Definition: nn_index.h:763
IndexParams index_params_
Definition: nn_index.h:851
size_t size_at_build_
Definition: nn_index.h:841
int max(int a, int b)
Definition: cutil_math.h:48
@ FLANN_CHECKS_UNLIMITED
Definition: defines.h:147
flann_algorithm_t
Definition: defines.h:80
@ FLANN_INDEX_KMEANS
Definition: defines.h:83
flann_centers_init_t
Definition: defines.h:96
@ FLANN_CENTERS_KMEANSPP
Definition: defines.h:99
@ FLANN_CENTERS_RANDOM
Definition: defines.h:97
@ FLANN_CENTERS_GONZALES
Definition: defines.h:98
static double dist(double x1, double y1, double x2, double y2)
Definition: lsd.c:207
void swap(optional< T > &x, optional< T > &y) noexcept(noexcept(x.swap(y)))
Definition: Optional.h:890
const binary_object make_binary_object(void *t, size_t size)
T get_param(const IndexParams &params, std::string name, const T &default_value)
Definition: params.h:95
std::map< std::string, any > IndexParams
Definition: params.h:51
void swap(cloudViewer::core::SmallVectorImpl< T > &LHS, cloudViewer::core::SmallVectorImpl< T > &RHS)
Implement std::swap in terms of SmallVector swap.
Definition: SmallVector.h:1370
#define USING_BASECLASS_SYMBOLS
Definition: nn_index.h:887
struct Index Index
Definition: sqlite3.c:14646
KMeansIndexParams(int branching=32, int iterations=11, flann_centers_init_t centers_init=FLANN_CENTERS_RANDOM, float cb_index=0.2)
Definition: kmeans_index.h:60
Definition: lsd.c:149