ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
confusionmatrix.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include "confusionmatrix.h"
9 
#include <CVConst.h>
#include <CVLog.h>
#include <QtCompat.h>

#include <QBrush>
#include <QFile>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <iterator>
#include <set>

#include "ui_confusionmatrix.h"
22 
23 static QColor GetColor(double value, double r1, double g1, double b1) {
24  double r0 = 255.0;
25  double g0 = 255.0;
26  double b0 = 255.0;
27  if (value < 0.05) {
28  value = 0.05;
29  } else if (value > 0.95) {
30  value = 0.95;
31  }
32  int r = static_cast<int>((r1 - r0) * value + r0);
33  int g = static_cast<int>((g1 - g0) * value + g0);
34  int b = static_cast<int>((b1 - b0) * value + b0);
35  // CVLog::Warning("value " + QString::number(value) + " (" +
36  // QString::number(r) + ", " + QString::number(g) + ", " +
37  // QString::number(b)
38  //+ ")");
39  return QColor(r, g, b);
40 }
41 
42 ConfusionMatrix::ConfusionMatrix(const std::vector<ScalarType>& actual,
43  const std::vector<ScalarType>& predicted)
44  : nbClasses(0),
45  ui(new Ui::ConfusionMatrix),
46  m_overallAccuracy(0.0f)
47 
48 {
49  ui->setupUi(this);
50  this->setWindowFlag(Qt::WindowStaysOnTopHint);
51 
52  compute(actual, predicted);
53 
54  this->ui->tableWidget->resizeColumnsToContents();
55  this->ui->tableWidget->setSizeAdjustPolicy(
56  QAbstractScrollArea::AdjustToContents);
57  QSize tableSize = this->ui->tableWidget->sizeHint();
58  QSize widgetSize = QSize(tableSize.width() + 30, tableSize.height() + 50);
59  this->setMinimumSize(widgetSize);
60 }
61 
63  delete ui;
64  ui = nullptr;
65 }
66 
68  cv::Mat& matrix, cv::Mat& precisionRecallF1Score, cv::Mat& vec_TP_FN) {
69  int nbClasses = matrix.rows;
70 
71  // compute precision
72  for (int predictedIdx = 0; predictedIdx < nbClasses; predictedIdx++) {
73  float TP = 0;
74  float FP = 0;
75  for (int realIdx = 0; realIdx < nbClasses; realIdx++) {
76  if (realIdx == predictedIdx)
77  TP = matrix.at<int>(realIdx, realIdx);
78  else
79  FP += matrix.at<int>(realIdx, predictedIdx);
80  }
81  float TP_FP = TP + FP;
82  if (TP_FP == 0)
83  precisionRecallF1Score.at<float>(predictedIdx, PRECISION) =
84  NAN_VALUE;
85  else
86  precisionRecallF1Score.at<float>(predictedIdx, PRECISION) =
87  TP / TP_FP;
88  }
89 
90  // compute recall
91  for (int realIdx = 0; realIdx < nbClasses; realIdx++) {
92  float TP = 0;
93  float FN = 0;
94  for (int predictedIdx = 0; predictedIdx < nbClasses; predictedIdx++) {
95  if (realIdx == predictedIdx)
96  TP = matrix.at<int>(realIdx, realIdx);
97  else
98  FN += matrix.at<int>(realIdx, predictedIdx);
99  }
100  float TP_FN = TP + FN;
101  if (TP_FN == 0)
102  precisionRecallF1Score.at<float>(realIdx, RECALL) = NAN_VALUE;
103  else
104  precisionRecallF1Score.at<float>(realIdx, RECALL) = TP / TP_FN;
105  vec_TP_FN.at<int>(realIdx, 0) = TP_FN;
106  }
107 
108  // compute F1-score
109  for (int realIdx = 0; realIdx < nbClasses; realIdx++) {
110  float den = precisionRecallF1Score.at<float>(realIdx, PRECISION) +
111  precisionRecallF1Score.at<float>(realIdx, RECALL);
112  if (den == 0)
113  precisionRecallF1Score.at<float>(realIdx, F1_SCORE) = NAN_VALUE;
114  else
115  precisionRecallF1Score.at<float>(realIdx, F1_SCORE) =
116  2 * precisionRecallF1Score.at<float>(realIdx, PRECISION) *
117  precisionRecallF1Score.at<float>(realIdx, RECALL) / den;
118  }
119 }
120 
122  int nbClasses = matrix.rows;
123  float totalTrue = 0;
124  float totalFalse = 0;
125 
126  m_overallAccuracy = 0.0;
127 
128  for (int realIdx = 0; realIdx < nbClasses; realIdx++) {
129  for (int predictedIdx = 0; predictedIdx < nbClasses; predictedIdx++) {
130  if (realIdx == predictedIdx)
131  totalTrue += matrix.at<int>(realIdx, realIdx);
132  else
133  totalFalse += matrix.at<int>(realIdx, predictedIdx);
134  }
135  }
136  if ((totalTrue + totalFalse) != 0)
137  m_overallAccuracy = totalTrue / (totalTrue + totalFalse);
138  else
139  m_overallAccuracy = NAN_VALUE;
140 
141  return m_overallAccuracy;
142 }
143 
144 void ConfusionMatrix::compute(const std::vector<ScalarType>& actual,
145  const std::vector<ScalarType>& predicted) {
146  int idxActual;
147  int idxPredicted;
148  int actualClass;
149  int predictedClass;
150 
151  // get the set of classes with the contents of the actual classes
152  std::set<ScalarType> classes(actual.begin(), actual.end());
153  int nbClasses = static_cast<int>(classes.size());
154  confusionMatrix = cv::Mat(nbClasses, nbClasses, CV_32S, cv::Scalar(0));
155  precisionRecallF1Score = cv::Mat(nbClasses, 3, CV_32F, cv::Scalar(0));
156  cv::Mat vec_TP_FN(nbClasses, 1, CV_32S, cv::Scalar(0));
157 
158  // fill the confusion matrix
159  for (int i = 0; i < actual.size(); i++) {
160  actualClass = actual.at(i);
161  idxActual = std::distance(classes.begin(), classes.find(actualClass));
162  predictedClass = predicted.at(i);
163  idxPredicted =
164  std::distance(classes.begin(), classes.find(predictedClass));
165  confusionMatrix.at<int>(idxActual, idxPredicted)++;
166  }
167 
168  // compute precision recall F1-score
169  computePrecisionRecallF1Score(confusionMatrix, precisionRecallF1Score,
170  vec_TP_FN);
171  float overallAccuracy = computeOverallAccuracy(confusionMatrix);
172 
173  // display the overall accuracy
174  this->ui->label_overallAccuracy->setText(
175  QString::number(overallAccuracy, 'g', 2));
176 
177  std::set<ScalarType>::iterator itB = classes.begin();
178  std::set<ScalarType>::iterator itE = classes.end();
179  class_numbers.assign(itB, itE);
180 
181  // BUILD THE QTABLEWIDGET
182 
183  this->ui->tableWidget->setColumnCount(
184  2 + nbClasses +
185  3); // +2 for titles, +3 for precision / recall / F1-score
186  this->ui->tableWidget->setRowCount(2 + nbClasses);
187  // create a font for the table widgets
188  QFont font;
189  font.setBold(true);
190  QTableWidgetItem* newItem = nullptr;
191  // set the row and column names
192  this->ui->tableWidget->setSpan(0, 0, 2, 2); // empty area
193  this->ui->tableWidget->setSpan(0, 2, 1, nbClasses); // 'Predicted' header
194  this->ui->tableWidget->setSpan(2, 0, nbClasses, 1); // 'Actual' header
195  this->ui->tableWidget->setSpan(0, 2 + nbClasses, 1, 3); // empty area
196  // Predicted
197  newItem = new QTableWidgetItem("Predicted");
198  newItem->setFont(font);
199  newItem->setBackground(Qt::lightGray);
200  newItem->setTextAlignment(Qt::AlignCenter);
201  this->ui->tableWidget->setItem(0, 2, newItem);
202  // Real
203  newItem = new QTableWidgetItem("Real");
204  newItem->setFont(font);
205  newItem->setBackground(Qt::lightGray);
206  newItem->setTextAlignment(Qt::AlignCenter);
207  this->ui->tableWidget->setItem(2, 0, newItem);
208  // add precision / recall / F1-score headers
209  newItem = new QTableWidgetItem("Precision");
210  newItem->setToolTip("TP / (TP + FP)");
211  newItem->setFont(font);
212  this->ui->tableWidget->setItem(1, 2 + nbClasses + PRECISION, newItem);
213  newItem = new QTableWidgetItem("Recall");
214  newItem->setToolTip("TP / (TP + FN)");
215  newItem->setFont(font);
216  this->ui->tableWidget->setItem(1, 2 + nbClasses + RECALL, newItem);
217  newItem = new QTableWidgetItem("F1-score");
218  newItem->setToolTip(
219  "Harmonic mean of precision and recall (the closer to 1 the "
220  "better)\n2 x precision x recall / (precision + recall)");
221  newItem->setFont(font);
222  this->ui->tableWidget->setItem(1, 2 + nbClasses + F1_SCORE, newItem);
223  // add column names and row names
224  for (int idx = 0; idx < class_numbers.size(); idx++) {
225  QString str = QString::number(class_numbers[idx]);
226  newItem = new QTableWidgetItem(str);
227  newItem->setFont(font);
228  this->ui->tableWidget->setItem(1, 2 + idx, newItem);
229  newItem = new QTableWidgetItem(str);
230  newItem->setFont(font);
231  this->ui->tableWidget->setItem(2 + idx, 1, newItem);
232  }
233 
234  // FILL THE QTABLEWIDGET
235 
236  // add the confusion matrix values
237  for (int row = 0; row < nbClasses; row++)
238  for (int column = 0; column < nbClasses; column++) {
239  double val = confusionMatrix.at<int>(row, column);
240  QTableWidgetItem* newItem =
241  new QTableWidgetItem(QString::number(val));
242  if (row == column) {
243  newItem->setBackground(
244  GetColor(val / vec_TP_FN.at<int>(row, 0), 0, 128, 255));
245  } else {
246  newItem->setBackground(
247  GetColor(val / vec_TP_FN.at<int>(row, 0), 200, 50, 50));
248  }
249  this->ui->tableWidget->setItem(2 + row, +2 + column, newItem);
250  }
251 
252  // set precision / recall / F1-score values
253  for (int realIdx = 0; realIdx < nbClasses; realIdx++) {
254  newItem = new QTableWidgetItem(QString::number(
255  precisionRecallF1Score.at<float>(realIdx, PRECISION), 'g', 2));
256  this->ui->tableWidget->setItem(2 + realIdx, 2 + nbClasses + PRECISION,
257  newItem);
258  newItem = new QTableWidgetItem(QString::number(
259  precisionRecallF1Score.at<float>(realIdx, RECALL), 'g', 2));
260  this->ui->tableWidget->setItem(2 + realIdx, 2 + nbClasses + RECALL,
261  newItem);
262  newItem = new QTableWidgetItem(QString::number(
263  precisionRecallF1Score.at<float>(realIdx, F1_SCORE), 'g', 2));
264  this->ui->tableWidget->setItem(2 + realIdx, 2 + nbClasses + F1_SCORE,
265  newItem);
266  }
267 }
268 
269 void ConfusionMatrix::setSessionRun(QString session, int run) {
270  QString label;
271 
272  label = session + " / " + QString::number(run);
273 
274  this->ui->label_sessionRun->setText(label);
275 }
276 
277 bool ConfusionMatrix::save(QString filePath) {
278  QFile file(filePath);
279 
280  if (!file.open(QIODevice::WriteOnly | QIODevice::Text)) {
281  CVLog::Error("impossible to open file: " + filePath);
282  return false;
283  }
284 
285  QTextStream stream(&file);
286  stream << "# columns: predicted classes\n# rows: actual classes\n";
287  stream << "# last three colums: precision / recall / F1-score\n";
288  for (auto class_number : class_numbers) {
289  stream << class_number << " ";
290  }
291  stream << QtCompat::endl;
292  for (int row = 0; row < confusionMatrix.rows; row++) {
293  stream << class_numbers.at(row) << " ";
294  for (int col = 0; col < confusionMatrix.cols; col++) {
295  stream << confusionMatrix.at<int>(row, col) << " ";
296  }
297  stream << precisionRecallF1Score.at<float>(row, PRECISION) << " ";
298  stream << precisionRecallF1Score.at<float>(row, RECALL) << " ";
299  stream << precisionRecallF1Score.at<float>(row, F1_SCORE)
300  << QtCompat::endl;
301  }
302 
303  file.close();
304 
305  return true;
306 }
307 
308 float ConfusionMatrix::getOverallAccuracy() { return m_overallAccuracy; }
constexpr ScalarType NAN_VALUE
NaN as a ScalarType value.
Definition: CVConst.h:76
static bool Error(const char *format,...)
Display an error dialog with formatted message.
Definition: CVLog.cpp:143
~ConfusionMatrix() override
void computePrecisionRecallF1Score(cv::Mat &matrix, cv::Mat &precisionRecallF1Score, cv::Mat &vec_TP_FN)
void setSessionRun(QString session, int run)
void compute(const std::vector< ScalarType > &actual, const std::vector< ScalarType > &predicted)
ConfusionMatrix(const std::vector< ScalarType > &actual, const std::vector< ScalarType > &predicted)
bool save(QString filePath)
float computeOverallAccuracy(cv::Mat &matrix)
static QColor GetColor(double value, double r1, double g1, double b1)
QTextStream & endl(QTextStream &stream)
Definition: QtCompat.h:718