21 #include "ui_confusionmatrix.h"
23 static QColor
GetColor(
double value,
double r1,
double g1,
double b1) {
29 }
else if (value > 0.95) {
32 int r =
static_cast<int>((r1 - r0) * value + r0);
33 int g =
static_cast<int>((g1 - g0) * value + g0);
34 int b =
static_cast<int>((b1 - b0) * value + b0);
39 return QColor(r, g, b);
43 const std::vector<ScalarType>& predicted)
46 m_overallAccuracy(0.0f)
50 this->setWindowFlag(Qt::WindowStaysOnTopHint);
54 this->ui->tableWidget->resizeColumnsToContents();
55 this->ui->tableWidget->setSizeAdjustPolicy(
56 QAbstractScrollArea::AdjustToContents);
57 QSize tableSize = this->ui->tableWidget->sizeHint();
58 QSize widgetSize = QSize(tableSize.width() + 30, tableSize.height() + 50);
59 this->setMinimumSize(widgetSize);
68 cv::Mat& matrix, cv::Mat& precisionRecallF1Score, cv::Mat& vec_TP_FN) {
69 int nbClasses = matrix.rows;
72 for (
int predictedIdx = 0; predictedIdx < nbClasses; predictedIdx++) {
75 for (
int realIdx = 0; realIdx < nbClasses; realIdx++) {
76 if (realIdx == predictedIdx)
77 TP = matrix.at<
int>(realIdx, realIdx);
79 FP += matrix.at<
int>(realIdx, predictedIdx);
81 float TP_FP = TP + FP;
83 precisionRecallF1Score.at<
float>(predictedIdx,
PRECISION) =
86 precisionRecallF1Score.at<
float>(predictedIdx,
PRECISION) =
91 for (
int realIdx = 0; realIdx < nbClasses; realIdx++) {
94 for (
int predictedIdx = 0; predictedIdx < nbClasses; predictedIdx++) {
95 if (realIdx == predictedIdx)
96 TP = matrix.at<
int>(realIdx, realIdx);
98 FN += matrix.at<
int>(realIdx, predictedIdx);
100 float TP_FN = TP + FN;
104 precisionRecallF1Score.at<
float>(realIdx,
RECALL) = TP / TP_FN;
105 vec_TP_FN.at<
int>(realIdx, 0) = TP_FN;
109 for (
int realIdx = 0; realIdx < nbClasses; realIdx++) {
110 float den = precisionRecallF1Score.at<
float>(realIdx,
PRECISION) +
111 precisionRecallF1Score.at<
float>(realIdx,
RECALL);
115 precisionRecallF1Score.at<
float>(realIdx,
F1_SCORE) =
116 2 * precisionRecallF1Score.at<
float>(realIdx,
PRECISION) *
117 precisionRecallF1Score.at<
float>(realIdx,
RECALL) / den;
122 int nbClasses = matrix.rows;
124 float totalFalse = 0;
126 m_overallAccuracy = 0.0;
128 for (
int realIdx = 0; realIdx < nbClasses; realIdx++) {
129 for (
int predictedIdx = 0; predictedIdx < nbClasses; predictedIdx++) {
130 if (realIdx == predictedIdx)
131 totalTrue += matrix.at<
int>(realIdx, realIdx);
133 totalFalse += matrix.at<
int>(realIdx, predictedIdx);
136 if ((totalTrue + totalFalse) != 0)
137 m_overallAccuracy = totalTrue / (totalTrue + totalFalse);
141 return m_overallAccuracy;
145 const std::vector<ScalarType>& predicted) {
152 std::set<ScalarType> classes(actual.begin(), actual.end());
153 int nbClasses =
static_cast<int>(classes.size());
154 confusionMatrix = cv::Mat(nbClasses, nbClasses, CV_32S, cv::Scalar(0));
155 precisionRecallF1Score = cv::Mat(nbClasses, 3, CV_32F, cv::Scalar(0));
156 cv::Mat vec_TP_FN(nbClasses, 1, CV_32S, cv::Scalar(0));
159 for (
int i = 0; i < actual.size(); i++) {
160 actualClass = actual.at(i);
161 idxActual = std::distance(classes.begin(), classes.find(actualClass));
162 predictedClass = predicted.at(i);
164 std::distance(classes.begin(), classes.find(predictedClass));
165 confusionMatrix.at<
int>(idxActual, idxPredicted)++;
174 this->ui->label_overallAccuracy->setText(
175 QString::number(overallAccuracy,
'g', 2));
177 std::set<ScalarType>::iterator itB = classes.begin();
178 std::set<ScalarType>::iterator itE = classes.end();
179 class_numbers.assign(itB, itE);
183 this->ui->tableWidget->setColumnCount(
186 this->ui->tableWidget->setRowCount(2 + nbClasses);
190 QTableWidgetItem* newItem =
nullptr;
192 this->ui->tableWidget->setSpan(0, 0, 2, 2);
193 this->ui->tableWidget->setSpan(0, 2, 1, nbClasses);
194 this->ui->tableWidget->setSpan(2, 0, nbClasses, 1);
195 this->ui->tableWidget->setSpan(0, 2 + nbClasses, 1, 3);
197 newItem =
new QTableWidgetItem(
"Predicted");
198 newItem->setFont(font);
199 newItem->setBackground(Qt::lightGray);
200 newItem->setTextAlignment(Qt::AlignCenter);
201 this->ui->tableWidget->setItem(0, 2, newItem);
203 newItem =
new QTableWidgetItem(
"Real");
204 newItem->setFont(font);
205 newItem->setBackground(Qt::lightGray);
206 newItem->setTextAlignment(Qt::AlignCenter);
207 this->ui->tableWidget->setItem(2, 0, newItem);
209 newItem =
new QTableWidgetItem(
"Precision");
210 newItem->setToolTip(
"TP / (TP + FP)");
211 newItem->setFont(font);
212 this->ui->tableWidget->setItem(1, 2 + nbClasses +
PRECISION, newItem);
213 newItem =
new QTableWidgetItem(
"Recall");
214 newItem->setToolTip(
"TP / (TP + FN)");
215 newItem->setFont(font);
216 this->ui->tableWidget->setItem(1, 2 + nbClasses +
RECALL, newItem);
217 newItem =
new QTableWidgetItem(
"F1-score");
219 "Harmonic mean of precision and recall (the closer to 1 the "
220 "better)\n2 x precision x recall / (precision + recall)");
221 newItem->setFont(font);
222 this->ui->tableWidget->setItem(1, 2 + nbClasses +
F1_SCORE, newItem);
224 for (
int idx = 0; idx < class_numbers.size(); idx++) {
225 QString str = QString::number(class_numbers[idx]);
226 newItem =
new QTableWidgetItem(str);
227 newItem->setFont(font);
228 this->ui->tableWidget->setItem(1, 2 + idx, newItem);
229 newItem =
new QTableWidgetItem(str);
230 newItem->setFont(font);
231 this->ui->tableWidget->setItem(2 + idx, 1, newItem);
237 for (
int row = 0; row < nbClasses; row++)
238 for (
int column = 0; column < nbClasses; column++) {
239 double val = confusionMatrix.at<
int>(row, column);
240 QTableWidgetItem* newItem =
241 new QTableWidgetItem(QString::number(val));
243 newItem->setBackground(
244 GetColor(val / vec_TP_FN.at<
int>(row, 0), 0, 128, 255));
246 newItem->setBackground(
247 GetColor(val / vec_TP_FN.at<
int>(row, 0), 200, 50, 50));
249 this->ui->tableWidget->setItem(2 + row, +2 + column, newItem);
253 for (
int realIdx = 0; realIdx < nbClasses; realIdx++) {
254 newItem =
new QTableWidgetItem(QString::number(
255 precisionRecallF1Score.at<
float>(realIdx,
PRECISION),
'g', 2));
256 this->ui->tableWidget->setItem(2 + realIdx, 2 + nbClasses +
PRECISION,
258 newItem =
new QTableWidgetItem(QString::number(
259 precisionRecallF1Score.at<
float>(realIdx,
RECALL),
'g', 2));
260 this->ui->tableWidget->setItem(2 + realIdx, 2 + nbClasses +
RECALL,
262 newItem =
new QTableWidgetItem(QString::number(
263 precisionRecallF1Score.at<
float>(realIdx,
F1_SCORE),
'g', 2));
264 this->ui->tableWidget->setItem(2 + realIdx, 2 + nbClasses +
F1_SCORE,
272 label = session +
" / " + QString::number(run);
274 this->ui->label_sessionRun->setText(label);
278 QFile file(filePath);
280 if (!file.open(QIODevice::WriteOnly | QIODevice::Text)) {
285 QTextStream stream(&file);
286 stream <<
"# columns: predicted classes\n# rows: actual classes\n";
287 stream <<
"# last three colums: precision / recall / F1-score\n";
288 for (
auto class_number : class_numbers) {
289 stream << class_number <<
" ";
292 for (
int row = 0; row < confusionMatrix.rows; row++) {
293 stream << class_numbers.at(row) <<
" ";
294 for (
int col = 0; col < confusionMatrix.cols; col++) {
295 stream << confusionMatrix.at<
int>(row, col) <<
" ";
297 stream << precisionRecallF1Score.at<
float>(row,
PRECISION) <<
" ";
298 stream << precisionRecallF1Score.at<
float>(row,
RECALL) <<
" ";
299 stream << precisionRecallF1Score.at<
float>(row,
F1_SCORE)
constexpr ScalarType NAN_VALUE
NaN as a ScalarType value.
static bool Error(const char *format,...)
Display an error dialog with formatted message.
float getOverallAccuracy()
~ConfusionMatrix() override
void computePrecisionRecallF1Score(cv::Mat &matrix, cv::Mat &precisionRecallF1Score, cv::Mat &vec_TP_FN)
void setSessionRun(QString session, int run)
void compute(const std::vector< ScalarType > &actual, const std::vector< ScalarType > &predicted)
ConfusionMatrix(const std::vector< ScalarType > &actual, const std::vector< ScalarType > &predicted)
bool save(QString filePath)
float computeOverallAccuracy(cv::Mat &matrix)
static QColor GetColor(double value, double r1, double g1, double b1)
QTextStream & endl(QTextStream &stream)