ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
q3DMASCClassifier.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include "q3DMASCClassifier.h"
9 
10 // Local
11 #include "ScalarFieldWrappers.h"
12 #include "q3DMASCTools.h"
13 
14 // qCC_db
15 #include <CVLog.h>
16 #include <ecvDisplayTools.h>
17 #include <ecvPointCloud.h>
18 #include <ecvProgressDialog.h>
19 #include <ecvScalarField.h>
20 
21 // qPDALIO
22 #include "../../../core/IO/qPDALIO/include/LASFields.h"
23 
24 // CCPluginAPI
25 #include <ecvMainAppInterface.h>
26 
27 // Qt
28 #include <QCoreApplication>
29 #include <QMessageBox>
30 #include <QProgressDialog>
31 #include <QtConcurrent>
32 
33 #include "confusionmatrix.h"
34 #include "qTrain3DMASCDialog.h"
35 
36 #if defined(_OPENMP)
37 #include <omp.h>
38 #endif
39 
40 #if defined(CV_MAC_OS) || defined(CV_LINUX)
41 #include <unistd.h>
42 #endif
43 
44 using namespace masc;
45 
47 
48 bool Classifier::isValid() const {
49  return (m_rtrees && m_rtrees->isClassifier() && m_rtrees->isTrained());
50 }
51 
53  const ccPointCloud* cloud) {
// Returns a wrapper giving per-point access to the values of the feature
// source `fs` on `cloud`, chosen by fs.type:
//  - a named scalar field (ScalarFieldWrapper), looked up by fs.name;
//  - three coordinate-dimension cases (DimScalarFieldWrapper) and three
//    color-channel cases (ColorScalarFieldWrapper) -- presumably X/Y/Z and
//    R/G/B; TODO confirm against the case labels.
// A null wrapper is returned when the named scalar field does not exist
// (an "Internal error: unknown scalar field" message is also emitted) or
// when fs.type matches none of the cases. Callers in this file treat a null
// or invalid wrapper as an error.
54  IScalarFieldWrapper::Shared source(nullptr);
55 
56  switch (fs.type) {
58  assert(!fs.name.isEmpty());
59  int sfIdx = cloud->getScalarFieldIndexByName(qPrintable(fs.name));
60  if (sfIdx >= 0) {
61  source.reset(
62  new ScalarFieldWrapper(cloud->getScalarField(sfIdx)));
63  } else {
65  QObject::tr("Internal error: unknown scalar field '%1'")
66  .arg(fs.name));
67  return IScalarFieldWrapper::Shared(nullptr);
68  }
69  } break;
70 
72  source.reset(new DimScalarFieldWrapper(
74  break;
76  source.reset(new DimScalarFieldWrapper(
78  break;
80  source.reset(new DimScalarFieldWrapper(
82  break;
83 
85  source.reset(new ColorScalarFieldWrapper(
87  break;
89  source.reset(new ColorScalarFieldWrapper(
91  break;
93  source.reset(new ColorScalarFieldWrapper(
95  break;
96  }
97 
// Still null here if fs.type matched no case above.
98  return source;
99 }
100 
// Classifier::classify
// Runs the trained random-trees model on every point of `cloud` and stores
// the predicted label of each point in a newly created classification SF.
//
// Side effects visible below:
//  - any "Classification_confidence" SF is deleted and re-created. It
//    receives (votes for the predicted class) / (number of trees), or
//    NAN_VALUE when the predicted label is not found among the vote columns;
//  - if a classification SF is found on the cloud (the lookup call itself is
//    not shown in this listing), it is renamed "Classification_backup"
//    (deleting an older "Classification_backup" SF first) and kept;
//  - the new classification SF is zero-filled, filled with predictions and
//    made the displayed SF.
//
// Feature columns are filled in `featureSources` order. Presumably this must
// match the order used in train(); it is not checked here.
//
// @param featureSources one source per feature column (must not be empty)
// @param cloud          cloud to classify (must not be null)
// @param errorMessage   receives a message on failure
// @param parentWidget   if set, a progress dialog is shown and the display is
//                       refreshed at the end
// @param app            only gates the ConfusionMatrix construction at the end
// @return false for a null cloud, an invalid classifier, no features, an
//         invalid feature source, allocation failure or user cancellation.
101 bool Classifier::classify(const Feature::Source::Set& featureSources,
102  ccPointCloud* cloud,
103  QString& errorMessage,
104  QWidget* parentWidget /*=nullptr*/,
105  ecvMainAppInterface* app /*nullptr*/
106 ) {
107  if (!cloud) {
108  assert(false);
109  errorMessage = QObject::tr("Invalid input");
110  return false;
111  }
112
113  if (!isValid()) {
114  errorMessage = QObject::tr("Invalid classifier");
115  return false;
116  }
117
// NOTE(review): this message mentions "Training" although this is classify().
118  if (featureSources.empty()) {
119  errorMessage = QObject::tr(
120  "Training method called without any feature (source)?!");
121  return false;
122  }
123
124  // add a ccConfidence value if needed
125  int cvConfidenceIdx =
126  cloud->getScalarFieldIndexByName("Classification_confidence");
127  if (cvConfidenceIdx >= 0) // if the scalar field exists, delete it
128  cloud->deleteScalarField(cvConfidenceIdx);
129  cvConfidenceIdx = cloud->addScalarField("Classification_confidence");
130  ccScalarField* cvConfidenceSF =
131  static_cast<ccScalarField*>(cloud->getScalarField(cvConfidenceIdx));
132
133  // look for the classification field
134  cloudViewer::ScalarField* classificationSF =
136  ccScalarField* classifSFBackup = nullptr;
137
138  if (classificationSF) // save classification field (if any) by renaming it
139  // "Classification_backup"
140  {
142  "Classification SF found: copy it in Classification_backup, a "
143  "confusion matrix will be generated");
144  // delete Classification_backup field (if any)
145  int sfIdx = cloud->getScalarFieldIndexByName("Classification_backup");
146  if (sfIdx >= 0) cloud->deleteScalarField(sfIdx);
147
148  classificationSF->setName(
149  "Classification_backup"); // rename the classification field
150  classifSFBackup = static_cast<ccScalarField*>(classificationSF);
151  }
152
// NOTE(review): the early returns from here on ("Not enough memory", invalid
// feature source) leave the re-created confidence SF and the
// "Classification_backup" rename in place on the cloud.
153  // create the classification SF
154  ccScalarField* _classificationSF =
156  if (!_classificationSF->resizeSafe(cloud->size())) {
157  _classificationSF->release();
158  errorMessage = QObject::tr("Not enough memory");
159  return false;
160  }
161  cloud->addScalarField(_classificationSF);
162  classificationSF = _classificationSF;
163
164  assert(classificationSF);
165  classificationSF->fill(0); // 0 = no classification?
166
167  int sampleCount = static_cast<int>(cloud->size());
168  int attributesPerSample = static_cast<int>(featureSources.size());
169
170  CVLog::Print(
171  QObject::tr("[3DMASC] Classifying %1 points with %2 feature(s)")
172  .arg(sampleCount)
173  .arg(attributesPerSample));
174
175  // create the field wrappers
176  std::vector<IScalarFieldWrapper::Shared> wrappers;
177  {
178  wrappers.reserve(attributesPerSample);
179  for (int fIndex = 0; fIndex < attributesPerSample; ++fIndex) {
180  const Feature::Source& fs = featureSources[fIndex];
181
182  IScalarFieldWrapper::Shared source = GetSource(fs, cloud);
183  if (!source || !source->isValid()) {
184  assert(false);
185  errorMessage =
186  QObject::tr("Internal error: invalid source '%1'")
187  .arg(fs.name);
188  return false;
189  }
190
191  wrappers.push_back(source);
192  }
193  }
194
195  QScopedPointer<ecvProgressDialog> pDlg;
196  if (parentWidget) {
197  pDlg.reset(new ecvProgressDialog(parentWidget));
198  pDlg->setLabelText(QString("Classify (%1 points)").arg(sampleCount));
199  pDlg->show();
200  QCoreApplication::processEvents();
201  }
202  cloudViewer::NormalizedProgress nProgress(pDlg.data(), cloud->size());
203
204  bool success = true;
205  int numberOfTrees = static_cast<int>(m_rtrees->getRoots().size());
206  bool cancelled = false;
207
// NOTE(review): in non-_DEBUG builds with OpenMP, the loop below runs in
// parallel.
//  - omp_get_max_threads() - 2 is <= 0 on machines with 1 or 2 hardware
//    threads, but OpenMP requires a positive num_threads value. Consider
//    std::max(1, omp_get_max_threads() - 2).
//  - errorMessage, success and cancelled are written by several threads
//    without synchronization (a data race). An atomic flag plus a single
//    guarded error string would be safer.
//  - "cancelled" only turns the remaining iterations into no-ops; an OpenMP
//    for-loop cannot be exited with break.
//  - a cv::Mat is allocated for every point. One buffer per thread would
//    avoid the per-point heap allocation.
208 #ifndef _DEBUG
209 #if defined(_OPENMP)
210 #pragma omp parallel for num_threads(omp_get_max_threads() - 2)
211 #endif
212 #endif
213  for (int i = 0; i < static_cast<int>(cloud->size()); ++i) {
214  {
215  // allocate the data matrix
216  cv::Mat test_data;
217  try {
218  test_data.create(1, attributesPerSample, CV_32FC1);
219  } catch (const cv::Exception& cvex) {
220  errorMessage = cvex.msg.c_str();
221  success = false;
222  cancelled = true;
223  }
224
225  if (!cancelled) {
226  for (int fIndex = 0; fIndex < attributesPerSample; ++fIndex) {
227  double value = wrappers[fIndex]->pointValue(i);
228  test_data.at<float>(0, fIndex) = static_cast<float>(value);
229  }
230
231  float predictedClass =
232  m_rtrees->predict(test_data.row(0), cv::noArray(),
233  cv::ml::DTrees::PREDICT_MAX_VOTE);
234  classificationSF->setValue(i, static_cast<int>(predictedClass));
// predict() above and getVotes() below each evaluate the model, so the
// forest is presumably traversed twice per point. Row 0 of `result` is
// read as class labels and row 1 as the vote count for each label.
235  // compute the confidence
236  cv::Mat result;
237  m_rtrees->getVotes(test_data, result,
238  cv::ml::DTrees::PREDICT_MAX_VOTE);
239  int classIndex = -1;
240  for (int col = 0; col < result.cols;
241  col++) // look for the index of the predicted class
242  if (predictedClass == result.at<int>(0, col)) {
243  classIndex = col;
244  break;
245  }
246  if (classIndex != -1) {
247  float nbVotes = result.at<int>(
248  1, classIndex); // get the number of votes
249  cvConfidenceSF->setValue(
250  i,
251  static_cast<ScalarType>(
252  nbVotes /
253  numberOfTrees)); // compute the confidence
254  } else
255  cvConfidenceSF->setValue(i, NAN_VALUE);
256
257  if (pDlg && !nProgress.oneStep()) {
258  // process cancelled by the user
259  success = false;
260  cancelled = true;
261  }
262  }
263  }
264  }
265
266  classificationSF->computeMinAndMax();
267  cvConfidenceSF->computeMinAndMax();
268
269  // show the classification field by default
270  {
271  int classifSFIdx =
272  cloud->getScalarFieldIndexByName(classificationSF->getName());
273  cloud->setCurrentDisplayedScalarField(classifSFIdx);
274  cloud->showSF(true);
275  }
276
277  if (parentWidget) {
279  cloud->setRedrawFlagRecursive(true);
280  ecvDisplayTools::RedrawDisplay(false, true);
281  QCoreApplication::processEvents();
282  }
283
// NOTE(review): the ConfusionMatrix below is allocated with new, but the
// local pointer is never used, shown, stored or deleted. It is therefore
// leaked and has no visible effect. evaluate() instead hands its matrix to
// the training dialog and calls show() when app is set.
284  if (classifSFBackup != nullptr) {
285  if (app) {
286  ConfusionMatrix* confusionMatrix =
287  new ConfusionMatrix(*classifSFBackup, *classificationSF);
288  }
289  }
290
291  return success;
292 }
293 
// Classifier::evaluate
// Predicts a label for every test sample (all points of `testCloud`, or only
// the points of `testSubset`) and compares it with the label stored in the
// cloud's classification SF.
//
// `metrics` is reset first, then filled:
//  - sampleCount = number of test samples;
//  - goodGuess = number of matching predictions;
//  - ratio = goodGuess / sampleCount.
// A ConfusionMatrix built from the actual and predicted labels is passed to
// train3DMASCDialog.addConfusionMatrixAndSaveTraces(), and is also shown when
// `app` is set.
//
// If `outputSFName` is not empty, an SF with that name (deleted first if it
// already exists) receives the predicted labels. A re-created
// "Classification_confidence" SF receives (votes for the predicted class) /
// (number of trees), or NAN_VALUE if the predicted label is not among the
// vote columns. The output SF then becomes the displayed SF.
//
// @param testSubset   optional subset; its associated cloud must be testCloud
// @param parentWidget if set, a progress dialog is shown and the display is
//                     refreshed at the end
// @return false for a null cloud, an untrained model, no features, a
//         mismatching subset, a missing/too-small classification SF, an
//         allocation failure, an invalid feature source or cancellation.
294 bool Classifier::evaluate(const Feature::Source::Set& featureSources,
295  ccPointCloud* testCloud,
296  AccuracyMetrics& metrics,
297  QString& errorMessage,
298  Train3DMASCDialog& train3DMASCDialog,
299  cloudViewer::ReferenceCloud* testSubset /*=nullptr=*/,
300  QString outputSFName /*=QString()*/,
301  QWidget* parentWidget /*=nullptr*/,
302  ecvMainAppInterface* app /*=nullptr*/) {
303  if (!testCloud) {
304  // invalid input
305  assert(false);
306  errorMessage = QObject::tr("Invalid input cloud");
307  return false;
308  }
309  metrics.sampleCount = metrics.goodGuess = 0;
310  metrics.ratio = 0.0f;
311
312  if (!m_rtrees || !m_rtrees->isTrained()) {
313  errorMessage = QObject::tr("Classifier hasn't been trained yet");
314  return false;
315  }
316
// NOTE(review): this message mentions "Training" although this is evaluate().
317  if (featureSources.empty()) {
318  errorMessage = QObject::tr(
319  "Training method called without any feature (source)?!");
320  return false;
321  }
322  if (testSubset && testSubset->getAssociatedCloud() != testCloud) {
323  errorMessage = QObject::tr(
324  "Invalid test subset (associated point cloud is different)");
325  return false;
326  }
327
328  // look for the classification field
330  if (!classifSF || classifSF->size() < testCloud->size()) {
331  assert(false);
332  errorMessage = QObject::tr(
333  "Missing/invalid 'Classification' field on input cloud");
334  return false;
335  }
336
337  ccScalarField* outSF = nullptr;
338  ccScalarField* cvConfidenceSF = nullptr;
339
// NOTE(review): if outputSFName (or "Classification_confidence") is the name
// of the SF that classifSF points to, deleting it below would leave classifSF
// dangling before it is read in the evaluation loop. Consider rejecting that
// case up front.
340  if (!outputSFName.isEmpty()) {
341  int outIdx =
342  testCloud->getScalarFieldIndexByName(qPrintable(outputSFName));
343  if (outIdx >= 0)
344  testCloud->deleteScalarField(outIdx);
345  else
346  CVLog::Print("add " + outputSFName + " to the TEST cloud");
347  outIdx = testCloud->addScalarField(qPrintable(outputSFName));
348  outSF = static_cast<ccScalarField*>(testCloud->getScalarField(outIdx));
349  }
350
351  if (outSF) // add a Classification_confidence value to the test cloud if
352  // needed
353  {
354  int cvConfidenceIdx = testCloud->getScalarFieldIndexByName(
355  "Classification_confidence");
356  if (cvConfidenceIdx >= 0) // if the scalar field exists, delete it
357  testCloud->deleteScalarField(cvConfidenceIdx);
358  else
359  CVLog::Print("add Classification_confidence to the TEST cloud");
360  cvConfidenceIdx =
361  testCloud->addScalarField("Classification_confidence");
362  cvConfidenceSF = static_cast<ccScalarField*>(
363  testCloud->getScalarField(cvConfidenceIdx));
364  }
365
366  unsigned testSampleCount =
367  (testSubset ? testSubset->size() : testCloud->size());
368  int attributesPerSample = static_cast<int>(featureSources.size());
369
370  CVLog::Print(
371  QObject::tr("[3DMASC] Testing data: %1 samples with %2 feature(s)")
372  .arg(testSampleCount)
373  .arg(attributesPerSample));
374
375  // allocate the data matrix
376  cv::Mat test_data;
377  try {
378  test_data.create(static_cast<int>(testSampleCount), attributesPerSample,
379  CV_32FC1);
380  } catch (const cv::Exception& cvex) {
381  errorMessage = cvex.msg.c_str();
382  return false;
383  }
384
385  QScopedPointer<ecvProgressDialog> pDlg;
386  if (parentWidget) {
387  pDlg.reset(new ecvProgressDialog(parentWidget));
388  pDlg->setLabelText(QString("Evaluating the classifier on %1 points")
389  .arg(testSampleCount));
390  pDlg->show();
391  QCoreApplication::processEvents();
392  }
393  cloudViewer::NormalizedProgress nProgress(pDlg.data(), testSampleCount);
394
395  // fill the data matrix
396  for (int fIndex = 0; fIndex < attributesPerSample; ++fIndex) {
397  const Feature::Source& fs = featureSources[fIndex];
398  IScalarFieldWrapper::Shared source = GetSource(fs, testCloud);
399  if (!source || !source->isValid()) {
400  assert(false);
401  errorMessage = QObject::tr("Internal error: invalid source '%1'")
402  .arg(fs.name);
403  return false;
404  }
405
406  for (unsigned i = 0; i < testSampleCount; ++i) {
407  unsigned pointIndex =
408  (testSubset ? testSubset->getPointGlobalIndex(i) : i);
409  double value = source->pointValue(pointIndex);
410  test_data.at<float>(i, fIndex) = static_cast<float>(value);
411  }
412  }
413
414  int numberOfTrees = static_cast<int>(m_rtrees->getRoots().size());
415
416  // estimate the efficiency of the classifier
417  std::vector<ScalarType> actualClass(testSampleCount);
418  std::vector<ScalarType> predictectedClass(testSampleCount);
419  {
420  metrics.sampleCount = testSampleCount;
421  metrics.goodGuess = 0;
422
423  for (unsigned i = 0; i < testSampleCount; ++i) {
424  unsigned pointIndex =
425  (testSubset ? testSubset->getPointGlobalIndex(i) : i);
426  ScalarType pointClass = classifSF->getValue(pointIndex);
427  int iClass = static_cast<int>(pointClass);
428  // if (iClass < 0 || iClass > 255)
429  //{
430  // errorMessage = QObject::tr("Classification values out of range
431  //(0-255)"); return false;
432  // }
433
434  float fPredictedClass =
435  m_rtrees->predict(test_data.row(i), cv::noArray(),
436  cv::ml::DTrees::PREDICT_MAX_VOTE);
437  int iPredictedClass = static_cast<int>(fPredictedClass);
438  actualClass.at(i) = iClass;
439  predictectedClass.at(i) = iPredictedClass;
440  if (iPredictedClass == iClass) {
441  ++metrics.goodGuess;
442  }
443  if (outSF) {
444  outSF->setValue(pointIndex,
445  static_cast<ScalarType>(iPredictedClass));
// NOTE(review): outSF above is written at pointIndex, but the confidence
// below is written at index i (the position in the test set). When
// testSubset is set, confidence values therefore land on the wrong points.
// This should presumably use pointIndex, as outSF does.
446  if (cvConfidenceSF) {
447  // compute the confidence
448  cv::Mat result;
449  m_rtrees->getVotes(test_data.row(i), result,
450  cv::ml::DTrees::PREDICT_MAX_VOTE);
451  int classIndex = -1;
452  for (int col = 0; col < result.cols;
453  col++) // look for the index of the predicted class
454  if (iPredictedClass == result.at<int>(0, col)) {
455  classIndex = col;
456  break;
457  }
458  if (classIndex != -1) {
459  float nbVotes = result.at<int>(
460  1, classIndex); // get the number of votes
461  cvConfidenceSF->setValue(
462  i, static_cast<ScalarType>(
463  nbVotes /
464  numberOfTrees)); // compute the
465  // confidence
466  } else
467  cvConfidenceSF->setValue(i, NAN_VALUE);
468  }
469  }
470
// NOTE(review): on cancellation this returns false without setting
// errorMessage, leaving outSF / cvConfidenceSF partially filled and without
// min/max computed.
471  if (pDlg && !nProgress.oneStep()) {
472  // process cancelled by the user
473  return false;
474  }
475  }
476
477  if (outSF) outSF->computeMinAndMax();
478  if (cvConfidenceSF) cvConfidenceSF->computeMinAndMax();
479
// NOTE(review): with zero test samples this is 0.0f / 0, i.e. NaN.
480  metrics.ratio =
481  static_cast<float>(metrics.goodGuess) / metrics.sampleCount;
482  }
483
// The matrix is heap-allocated and handed to the training dialog;
// presumably the dialog takes ownership -- TODO confirm.
484  ConfusionMatrix* confusionMatrix =
485  new ConfusionMatrix(actualClass, predictectedClass);
486  train3DMASCDialog.addConfusionMatrixAndSaveTraces(confusionMatrix);
487  if (app) {
488  confusionMatrix->show();
489  }
490
491  // show the Classification_prediction field by default
492  if (outSF) {
493  int classifSFIdx =
494  testCloud->getScalarFieldIndexByName(outSF->getName());
495  testCloud->setCurrentDisplayedScalarField(classifSFIdx);
496  testCloud->showSF(true);
497  }
498
499  if (parentWidget) {
501  testCloud->setRedraw(true);
502  ecvDisplayTools::RedrawDisplay(false, true);
503  QCoreApplication::processEvents();
504  }
505
506  return true;
507 }
508 
// Classifier::train
// Builds a new cv::ml::RTrees model (assigned to m_rtrees) from the feature
// values of `cloud` (or of the points referenced by `trainSubset`) and the
// labels stored in the cloud's classification SF.
//
// `params` supplies maxDepth, minSampleCount, activeVarCount and
// maxTreeCount. The response (last varTypes entry) is declared categorical,
// and the result must report isClassifier().
//
// The training itself runs through QtConcurrent::run on another thread.
// Meanwhile this thread polls the future every 500 ms and pumps the Qt event
// loop. When `parentWidget` is set, an indeterminate progress dialog is
// shown. Cancelling is not supported: a message box is shown and the dialog
// is reset.
//
// @param trainSubset optional subset; its associated cloud must be `cloud`
// @param app         if set, receives a console message with the sample and
//                    feature counts
// @return true on success. When training fails, m_rtrees is released and
//         errorMessage is set.
509 bool Classifier::train(const ccPointCloud* cloud,
510  const RandomTreesParams& params,
511  const Feature::Source::Set& featureSources,
512  QString& errorMessage,
513  cloudViewer::ReferenceCloud* trainSubset /*=nullptr*/,
514  ecvMainAppInterface* app /*=nullptr*/,
515  QWidget* parentWidget /*=nullptr*/) {
516  if (featureSources.empty()) {
517  errorMessage = QObject::tr(
518  "Training method called without any feature (source)?!");
519  return false;
520  }
521  if (!cloud) {
522  errorMessage = QObject::tr("Invalid input cloud");
523  return false;
524  }
525
526  if (trainSubset && trainSubset->getAssociatedCloud() != cloud) {
527  errorMessage = QObject::tr(
528  "Invalid train subset (associated point cloud is different)");
529  return false;
530  }
531
532  // look for the classification field
534  if (!classifSF || classifSF->size() < cloud->size()) {
535  assert(false);
536  errorMessage = QObject::tr(
537  "Missing/invalid 'Classification' field on input cloud");
538  return false;
539  }
540
541  int sampleCount =
542  static_cast<int>(trainSubset ? trainSubset->size() : cloud->size());
543  int attributesPerSample = static_cast<int>(featureSources.size());
544
545  if (app) {
546  app->dispToConsole(
547  QString("[3DMASC] Training data: %1 samples with %2 feature(s)")
548  .arg(sampleCount)
549  .arg(attributesPerSample));
550  }
551
552  cv::Mat training_data, train_labels;
553  try {
554  training_data.create(sampleCount, attributesPerSample, CV_32FC1);
555  train_labels.create(sampleCount, 1, CV_32FC1);
556  } catch (const cv::Exception& cvex) {
557  errorMessage = cvex.msg.c_str();
558  return false;
559  }
560
561  // fill the classification labels vector
562  {
563  for (int i = 0; i < sampleCount; ++i) {
564  int pointIndex =
565  (trainSubset ? static_cast<int>(
566  trainSubset->getPointGlobalIndex(i))
567  : i);
568  ScalarType pointClass = classifSF->getValue(pointIndex);
569  int iClass = static_cast<int>(pointClass);
570  // if (iClass < 0 || iClass > 255)
571  //{
572  // errorMessage = QObject::tr("Classification values out of range
573  //(0-255)"); return false;
574  // }
575
// NOTE(review): labels pass through static_cast<unsigned char>, so class
// values outside [0, 255] (including negative ones) silently wrap. The range
// check above is commented out.
576  train_labels.at<float>(i) = static_cast<unsigned char>(iClass);
577  }
578  }
579
580  // fill the training data matrix
581  for (int fIndex = 0; fIndex < attributesPerSample; ++fIndex) {
582  const Feature::Source& fs = featureSources[fIndex];
583
584  IScalarFieldWrapper::Shared source = GetSource(fs, cloud);
585  if (!source || !source->isValid()) {
586  assert(false);
587  errorMessage = QObject::tr("Internal error: invalid source '%1'")
588  .arg(fs.name);
589  return false;
590  }
591
592  for (int i = 0; i < sampleCount; ++i) {
593  int pointIndex =
594  (trainSubset ? static_cast<int>(
595  trainSubset->getPointGlobalIndex(i))
596  : i);
597  double value = source->pointValue(pointIndex);
598  training_data.at<float>(i, fIndex) = static_cast<float>(value);
599  }
600  }
601
602  QScopedPointer<QProgressDialog> pDlg;
603  if (parentWidget) {
604  pDlg.reset(new QProgressDialog(parentWidget));
605  pDlg->setRange(0, 0); // infinite loop
606  pDlg->setLabelText("Training classifier");
607  pDlg->show();
608  QCoreApplication::processEvents();
609  }
610
611  m_rtrees = cv::ml::RTrees::create();
612  m_rtrees->setMaxDepth(params.maxDepth);
613  m_rtrees->setMinSampleCount(params.minSampleCount);
614  m_rtrees->setRegressionAccuracy(0);
615  // If true then surrogate splits will be built. These splits allow to work
616  // with missing data and compute variable importance correctly. Default
617  // value is false.
618  m_rtrees->setUseSurrogates(false);
619  m_rtrees->setPriors(cv::Mat());
620  // m_rtrees->setMaxCategories(params.maxCategories); //not important?
621  m_rtrees->setCalculateVarImportance(true);
622  m_rtrees->setActiveVarCount(params.activeVarCount);
// Only MAX_ITER is set in the criteria type, so maxTreeCount bounds the
// number of trees and the epsilon value is not used as a stopping criterion.
623  cv::TermCriteria terminationCriteria(
624  cv::TermCriteria::MAX_ITER, params.maxTreeCount,
625  std::numeric_limits<double>::epsilon());
626  m_rtrees->setTermCriteria(terminationCriteria);
627
// The lambda captures everything by reference, including errorMessage and
// this (m_rtrees). That is only safe because the loop below waits until the
// future has finished.
628  QFuture<bool> future = QtConcurrent::run([&]() {
629  // Code in this block will run in another thread
630  try {
// NOTE(review): this is an all-zero CV_8U mask passed as sampleIdx. A mask
// of zeros nominally selects no samples; training apparently relies on
// OpenCV treating the resulting empty index list as "use all samples".
// Passing cv::noArray() would state that intent directly -- confirm against
// the cv::ml::TrainData::create documentation.
631  cv::Mat sampleIndexes =
632  cv::Mat::zeros(1, training_data.rows, CV_8U);
633  // cv::Mat trainSamples = sampleIndexes.colRange(0,
634  // sampleCount);
635  // trainSamples.setTo(cv::Scalar::all(1));
636
637  cv::Mat varTypes(training_data.cols + 1, 1, CV_8U);
638  varTypes.setTo(cv::Scalar::all(cv::ml::VAR_ORDERED));
639  varTypes.at<uchar>(training_data.cols) = cv::ml::VAR_CATEGORICAL;
640
641  cv::Ptr<cv::ml::TrainData> trainData = cv::ml::TrainData::create(
642  training_data, cv::ml::ROW_SAMPLE,
643  train_labels, /* samples layout responses */
644  cv::noArray(), sampleIndexes, /* varIdx sampleIdx */
645  cv::noArray(), varTypes); // sampleWeights varType
646
647  bool success = m_rtrees->train(trainData);
648  if (!success || !m_rtrees->isClassifier()) {
649  errorMessage = "Training failed";
650  return false;
651  }
652  } catch (const cv::Exception& cvex) {
653  m_rtrees.release();
654  errorMessage = cvex.msg.c_str();
655  return false;
656  } catch (const std::exception& stdex) {
657  errorMessage = stdex.what();
658  return false;
659  } catch (...) {
660  errorMessage = QObject::tr("Unknown error");
661  return false;
662  }
663
664  return true;
665  });
666
// NOTE(review): this polling loop sleeps for 500 ms on the calling (GUI)
// thread between event-processing calls, so the UI only refreshes about
// twice per second while training runs.
667  while (!future.isFinished()) {
668 #if defined(CV_WINDOWS)
669  ::Sleep(500);
670 #else
671  usleep(500 * 1000);
672 #endif
673  if (pDlg) {
674  if (pDlg->wasCanceled()) {
675  // future.cancel();
676  QMessageBox msgBox;
677  msgBox.setText(
678  "The training is still in progress, not possible to "
679  "cancel.");
680  msgBox.exec();
681  // break;
682  pDlg->reset();
683  pDlg->show();
684  }
685  pDlg->setValue(pDlg->value() + 1);
686  }
687  QCoreApplication::processEvents();
688  }
689
690  if (pDlg) {
691  pDlg->close();
692  QCoreApplication::processEvents();
693  }
694
// NOTE(review): when the worker returned false it has already set a specific
// errorMessage (e.g. the OpenCV exception text). The assignment below
// overwrites it with a generic message.
695  if (future.isCanceled() || !future.result() || !m_rtrees->isTrained()) {
696  errorMessage = QObject::tr("Training failed for an unknown reason...");
697  m_rtrees.release();
698  return false;
699  }
700
701  return true;
702 }
703 
705  QWidget* parentWidget /*=nullptr*/) const {
// Saves the classifier to `filename` using the model's save() method while an
// indeterminate progress dialog is shown, then logs the destination path.
// Returns false only when no model exists (m_rtrees is null).
// NOTE(review):
//  - the model is not checked with isTrained()/isValid() before saving;
//  - the progress dialog is created and shown even when parentWidget is null;
//  - save() returns nothing, and any cv::Exception it may throw is not caught
//    here, so "Classifier file saved to" is printed without confirming that
//    the write succeeded.
706  if (!m_rtrees) {
708  QObject::tr("Classifier hasn't been trained, can't save it"));
709  return false;
710  }
711
712  // save the classifier
713  QProgressDialog pDlg(parentWidget);
714  pDlg.setRange(0, 0); // infinite loop
715  pDlg.setLabelText(QObject::tr("Saving classifier"));
716  pDlg.show();
717  QCoreApplication::processEvents();
718
719  cv::String cvFilename = filename.toStdString();
720  m_rtrees->save(cvFilename);
721
722  pDlg.close();
723  QCoreApplication::processEvents();
724
725  CVLog::Print("Classifier file saved to: " +
726  QString::fromStdString(cvFilename));
727  return true;
728 }
729 
731  QWidget* parentWidget /*=nullptr*/) {
// Loads a cv::ml::RTrees model from `filename` into m_rtrees. An
// indeterminate progress dialog is shown when parentWidget is set.
// Returns false in two cases:
//  - loading throws a cv::Exception (m_rtrees keeps its previous value);
//  - the loaded model is empty or is not a classifier.
// An untrained model is only reported (presumably as a warning), and the
// function still returns true.
732  // load the classifier
733  QScopedPointer<QProgressDialog> pDlg;
734  if (parentWidget) {
735  pDlg.reset(new QProgressDialog(parentWidget));
736  pDlg->setRange(0, 0); // infinite loop
737  pDlg->setLabelText(QObject::tr("Loading classifier"));
738  pDlg->show();
739  QCoreApplication::processEvents();
740  }
741
742  try {
743  m_rtrees = cv::ml::RTrees::load(filename.toStdString());
744  } catch (const cv::Exception& cvex) {
745  CVLog::Warning(cvex.msg.c_str());
746  CVLog::Error("Failed to load file: " + filename);
747  return false;
748  }
749
750  if (pDlg) {
751  pDlg->close();
752  QCoreApplication::processEvents();
753  }
754
// NOTE(review): cv::Algorithm::load() (used by RTrees::load) returns an empty
// cv::Ptr when the file contains no model node or the model read is empty.
// m_rtrees may therefore be null here, and m_rtrees->empty() would
// dereference it. Test `!m_rtrees` first.
755  if (m_rtrees->empty() || !m_rtrees->isClassifier()) {
756  CVLog::Error(QObject::tr("Loaded classifier is invalid"));
757  return false;
758  } else if (!m_rtrees->isTrained()) {
760  QObject::tr("Loaded classifier doesn't seem to be trained"));
761  }
762
763  return true;
764 }
constexpr ScalarType NAN_VALUE
NaN as a ScalarType value.
Definition: CVConst.h:76
std::string filename
const char LAS_FIELD_NAMES[][28]
Definition: LASFields.h:63
@ LAS_CLASSIFICATION
Definition: LASFields.h:44
cmdLineReadable * params[]
core::Tensor result
Definition: VtkUtils.cpp:76
virtual void release()
Decreases the counter and deletes the object when it reaches 0.
Definition: CVShareable.cpp:35
static bool Warning(const char *format,...)
Prints out a formatted warning message in console.
Definition: CVLog.cpp:133
static bool Print(const char *format,...)
Prints out a formatted message in console.
Definition: CVLog.cpp:113
static bool Error(const char *format,...)
Displays an error dialog with a formatted message.
Definition: CVLog.cpp:143
QSharedPointer< IScalarFieldWrapper > Shared
3DMASC plugin 'train' dialog
void addConfusionMatrixAndSaveTraces(ConfusionMatrix *ptr)
virtual void setRedraw(bool state)
Sets entity redraw mode.
virtual void showSF(bool state)
Sets active scalarfield visibility.
void setRedrawFlagRecursive(bool redraw=false)
A 3D cloud and its associated features (color, normals, scalar fields, etc.)
void setCurrentDisplayedScalarField(int index)
Sets the currently displayed scalar field.
int addScalarField(const char *uniqueName) override
Creates a new scalar field and registers it.
void deleteScalarField(int index) override
Deletes a specific scalar field.
A scalar field associated to display-related parameters.
void computeMinAndMax() override
Determines the min and max values.
bool oneStep()
Increments the total progress value by a single unit.
int getScalarFieldIndexByName(const char *name) const
Returns the index of a scalar field represented by its name.
ScalarField * getScalarField(int index) const
Returns a pointer to a specific scalar field.
unsigned size() const override
Definition: PointCloudTpl.h:38
A very simple point cloud (no point duplication)
virtual GenericIndexedCloudPersist * getAssociatedCloud()
Returns the associated (source) cloud.
unsigned size() const override
Returns the number of points.
virtual unsigned getPointGlobalIndex(unsigned localIndex) const
A simple scalar field (to be associated to a point cloud)
Definition: ScalarField.h:25
void fill(ScalarType fillValue=0)
Fills the array with a particular value.
Definition: ScalarField.h:77
virtual void computeMinAndMax()
Determines the min and max values.
Definition: ScalarField.h:123
ScalarType & getValue(std::size_t index)
Definition: ScalarField.h:92
void setValue(std::size_t index, ScalarType value)
Definition: ScalarField.h:96
const char * getName() const
Returns scalar field name.
Definition: ScalarField.h:43
void setName(const char *name)
Sets scalar field name.
Definition: ScalarField.cpp:22
bool resizeSafe(std::size_t count, bool initNewElements=false, ScalarType valueForNewElements=0)
Resizes memory (no exception thrown)
Definition: ScalarField.cpp:81
static void SetRedrawRecursive(bool redraw=false)
static void RedrawDisplay(bool only2D=false, bool forceRedraw=true)
Main application interface (for plugins)
virtual void dispToConsole(QString message, ConsoleMessageLevel level=STD_CONSOLE_MESSAGE)=0
Graphical progress indicator (thread-safe)
bool toFile(QString filename, QWidget *parentWidget=nullptr) const
Saves the classifier to file.
bool isValid() const
Returns whether the classifier is valid or not.
bool evaluate(const Feature::Source::Set &featureSources, ccPointCloud *testCloud, AccuracyMetrics &metrics, QString &errorMessage, Train3DMASCDialog &train3DMASCDialog, cloudViewer::ReferenceCloud *testSubset=nullptr, QString outputSFName=QString(), QWidget *parentWidget=nullptr, ecvMainAppInterface *app=nullptr)
Evaluates the classifier.
bool classify(const Feature::Source::Set &featureSources, ccPointCloud *cloud, QString &errorMessage, QWidget *parentWidget=nullptr, ecvMainAppInterface *app=nullptr)
Applies the classifier.
bool train(const ccPointCloud *cloud, const RandomTreesParams &params, const Feature::Source::Set &featureSources, QString &errorMessage, cloudViewer::ReferenceCloud *trainSubset=nullptr, ecvMainAppInterface *app=nullptr, QWidget *parentWidget=nullptr)
Trains the classifier.
bool fromFile(QString filename, QWidget *parentWidget=nullptr)
Loads the classifier from file.
cv::Ptr< cv::ml::RTrees > m_rtrees
Random trees (OpenCV)
Classifier()
Default constructor.
static cloudViewer::ScalarField * GetClassificationSF(const ccPointCloud *cloud)
Helper: returns the classification SF associated to a cloud (if any)
void Sleep(int milliseconds)
Definition: Helper.cpp:278
unsigned char uchar
Definition: matrix.h:41
3DMASC classifier
static IScalarFieldWrapper::Shared GetSource(const Feature::Source &fs, const ccPointCloud *cloud)
cloudViewer::NormalizedProgress * nProgress
Classifier accuracy metrics.
Sources of values for this feature.
std::vector< Source > Set