R/confusion_matrix.R
confusion_matrix.Rd
Given a vector of predictions and target values, calculate numerous statistics of interest. Modified from m-clark/confusion_matrix.
confusion_matrix(
prediction,
target,
positive = NULL,
prevalence = NULL,
dnn = c("Predicted", "Target"),
longer = FALSE,
...
)
A vector of predictions
A vector of target values
The positive class for a 2-class setting. Default is NULL, which will result in using the first level of target.
Prevalence rate. Default is NULL.
The row and column headers for the contingency table returned. Default is 'Predicted' for rows and 'Target' for columns.
Transpose the output to long form. Default is FALSE (requires tidyr 1.0).
Other parameters, not currently used.
A list of tibble(s) with the associated statistics and possibly the frequency table as list column of the first element. If classes contain >1 numeric class and a single non-numeric class (e.g., "1", "2", "3", and "Unrelated"), the RMSE of the reciprocal of the Targets + 0.5 will also be returned.
This returns accuracy, agreement, and other statistics. See the
functions below to find out more. Originally inspired by the
confusionMatrix
function from the caret
package.
Kuhn, M., & Johnson, K. (2013). Applied predictive modeling.
prediction = c(0,1,1,0,0,1,0,1,1,1)
target = c(0,1,1,1,0,1,0,1,0,1)
confusion_matrix(prediction, target, positive = '1')
#> $Accuracy
#> # A tibble: 1 × 5
#> Accuracy `Accuracy LL` `Accuracy UL` `Accuracy Guessing` `Accuracy P-value`
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.8 0.444 0.975 0.6 0.167
#>
#> $Other
#> # A tibble: 1 × 17
#> Positive N N Posit…¹ N Neg…² Sensi…³ Speci…⁴ PPV/P…⁵ NPV F1/Di…⁶ Preva…⁷
#> <chr> <int> <int> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1 10 6 4 0.833 0.75 0.833 0.75 0.833 0.6
#> # … with 7 more variables: `Detection Rate` <dbl>,
#> # `Detection Prevalence` <dbl>, `Balanced Accuracy` <dbl>, FDR <dbl>,
#> # FOR <dbl>, `FPR/Fallout` <dbl>, FNR <dbl>, and abbreviated variable names
#> # ¹`N Positive`, ²`N Negative`, ³`Sensitivity/Recall/TPR`,
#> # ⁴`Specificity/TNR`, ⁵`PPV/Precision`, ⁶`F1/Dice`, ⁷Prevalence
#>
set.seed(42)
prediction = sample(letters[1:4], 250, replace = TRUE, prob = 1:4)
target = sample(letters[1:4], 250, replace = TRUE, prob = 1:4)
confusion_matrix(prediction, target)
#> Reciprocal RMSE not calculated: more than one non-numeric class.
#> $Accuracy
#> # A tibble: 1 × 5
#> Accuracy `Accuracy LL` `Accuracy UL` `Accuracy Guessing` `Accuracy P-value`
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.276 0.222 0.336 0.452 1.00
#>
#> $Other
#> # A tibble: 5 × 15
#> Class N Sensitiv…¹ Speci…² PPV/P…³ NPV F1/Di…⁴ Preva…⁵ Detec…⁶ Detec…⁷
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 a 27 0.111 0.879 0.1 0.891 0.105 0.108 0.012 0.12
#> 2 b 39 0.154 0.782 0.115 0.833 0.132 0.156 0.024 0.208
#> 3 c 71 0.282 0.709 0.278 0.713 0.280 0.284 0.08 0.288
#> 4 d 113 0.354 0.591 0.417 0.526 0.383 0.452 0.16 0.384
#> 5 Average 62.5 0.225 0.740 0.227 0.741 0.225 0.25 0.069 0.25
#> # … with 5 more variables: `Balanced Accuracy` <dbl>, FDR <dbl>, FOR <dbl>,
#> # `FPR/Fallout` <dbl>, FNR <dbl>, and abbreviated variable names
#> # ¹`Sensitivity/Recall/TPR`, ²`Specificity/TNR`, ³`PPV/Precision`,
#> # ⁴`F1/Dice`, ⁵Prevalence, ⁶`Detection Rate`, ⁷`Detection Prevalence`
#>
#> $Table
#> Target
#> Predicted a b c d
#> a 3 3 6 18
#> b 5 6 21 20
#> c 8 9 20 35
#> d 11 21 24 40
#>
#> $recip_rmse
#> [1] NA
#>
prediction = c(rep(1, 50), rep(2, 40), rep(3, 60))
target = c(rep(1, 50), rep(2, 50), rep(3, 50))
confusion_matrix(prediction, target)
#> $Accuracy
#> # A tibble: 1 × 5
#> Accuracy `Accuracy LL` `Accuracy UL` `Accuracy Guessing` `Accuracy P-value`
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.933 0.881 0.968 0.333 3.36e-54
#>
#> $Other
#> # A tibble: 4 × 15
#> Class N Sensitiv…¹ Speci…² PPV/P…³ NPV F1/Di…⁴ Preva…⁵ Detec…⁶ Detec…⁷
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1 50 1 1 1 1 1 0.333 0.333 0.333
#> 2 2 50 0.8 1 1 0.909 0.889 0.333 0.267 0.267
#> 3 3 50 1 0.9 0.833 1 0.909 0.333 0.333 0.4
#> 4 Average 50 0.933 0.967 0.944 0.970 0.933 0.333 0.311 0.333
#> # … with 5 more variables: `Balanced Accuracy` <dbl>, FDR <dbl>, FOR <dbl>,
#> # `FPR/Fallout` <dbl>, FNR <dbl>, and abbreviated variable names
#> # ¹`Sensitivity/Recall/TPR`, ²`Specificity/TNR`, ³`PPV/Precision`,
#> # ⁴`F1/Dice`, ⁵Prevalence, ⁶`Detection Rate`, ⁷`Detection Prevalence`
#>
#> $Table
#> Target
#> Predicted 1 2 3
#> 1 50 0 0
#> 2 0 40 0
#> 3 0 10 50
#>
#> $recip_rmse
#> [1] 0.02950844
#>
confusion_matrix(prediction, target) %>% purrr::pluck("Table")
#> Target
#> Predicted 1 2 3
#> 1 50 0 0
#> 2 0 40 0
#> 3 0 10 50
confusion_matrix(prediction, target, longer=TRUE)
#> $Accuracy
#> # A tibble: 5 × 2
#> Statistic Value
#> <chr> <dbl>
#> 1 Accuracy 9.33e- 1
#> 2 Accuracy LL 8.81e- 1
#> 3 Accuracy UL 9.68e- 1
#> 4 Accuracy Guessing 3.33e- 1
#> 5 Accuracy P-value 3.36e-54
#>
#> $Other
#> # A tibble: 56 × 3
#> Class Statistic Value
#> <chr> <chr> <dbl>
#> 1 1 N 50
#> 2 1 Sensitivity/Recall/TPR 1
#> 3 1 Specificity/TNR 1
#> 4 1 PPV/Precision 1
#> 5 1 NPV 1
#> 6 1 F1/Dice 1
#> 7 1 Prevalence 0.333
#> 8 1 Detection Rate 0.333
#> 9 1 Detection Prevalence 0.333
#> 10 1 Balanced Accuracy 1
#> # … with 46 more rows
#>
#> $Table
#> Target
#> Predicted 1 2 3
#> 1 50 0 0
#> 2 0 40 0
#> 3 0 10 50
#>
#> $recip_rmse
#> [1] 0.02950844
#>
confusion_matrix(prediction, target, longer=TRUE) %>%
purrr::pluck("Other") %>%
tidyr::spread(Class, Value)
#> # A tibble: 14 × 5
#> Statistic `1` `2` `3` Average
#> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 Balanced Accuracy 1 0.9 0.95 0.95
#> 2 Detection Prevalence 0.333 0.267 0.4 0.333
#> 3 Detection Rate 0.333 0.267 0.333 0.311
#> 4 F1/Dice 1 0.889 0.909 0.933
#> 5 FDR 0 0 0.167 0.0556
#> 6 FNR 0 0.2 0 0.0667
#> 7 FOR 0 0.0909 0 0.0303
#> 8 FPR/Fallout 0 0 0.1 0.0333
#> 9 N 50 50 50 50
#> 10 NPV 1 0.909 1 0.970
#> 11 PPV/Precision 1 1 0.833 0.944
#> 12 Prevalence 0.333 0.333 0.333 0.333
#> 13 Sensitivity/Recall/TPR 1 0.8 1 0.933
#> 14 Specificity/TNR 1 1 0.9 0.967
# Prediction with an unrelated class
prediction = c(rep(1, 50), rep(2, 40), rep(3, 60), rep("Unrelated", 55))
target = c(rep(1, 50), rep(2, 50), rep(3, 55), rep("Unrelated", 50))
confusion_matrix(prediction, target)
#> $Accuracy
#> # A tibble: 1 × 5
#> Accuracy `Accuracy LL` `Accuracy UL` `Accuracy Guessing` `Accuracy P-value`
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.927 0.882 0.958 0.268 5.56e-89
#>
#> $Other
#> # A tibble: 5 × 15
#> Class N Sensit…¹ Speci…² PPV/P…³ NPV F1/Di…⁴ Preva…⁵ Detec…⁶ Detec…⁷
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1 50 1 1 1 1 1 0.244 0.244 0.244
#> 2 2 50 0.8 1 1 0.939 0.889 0.244 0.195 0.195
#> 3 3 55 0.909 0.933 0.833 0.966 0.870 0.268 0.244 0.293
#> 4 Unrelated 50 1 0.968 0.909 1 0.952 0.244 0.244 0.268
#> 5 Average 51.2 0.927 0.975 0.936 0.976 0.928 0.25 0.232 0.25
#> # … with 5 more variables: `Balanced Accuracy` <dbl>, FDR <dbl>, FOR <dbl>,
#> # `FPR/Fallout` <dbl>, FNR <dbl>, and abbreviated variable names
#> # ¹`Sensitivity/Recall/TPR`, ²`Specificity/TNR`, ³`PPV/Precision`,
#> # ⁴`F1/Dice`, ⁵Prevalence, ⁶`Detection Rate`, ⁷`Detection Prevalence`
#>
#> $Table
#> Target
#> Predicted 1 2 3 Unrelated
#> 1 50 0 0 0
#> 2 0 40 0 0
#> 3 0 10 50 0
#> Unrelated 0 0 5 50
#>
#> $recip_rmse
#> [1] 0.02711929
#>
# Prediction with two unrelated classes
prediction = c(rep(1, 50), rep(2, 40), rep("Third", 60), rep("Unrelated", 55))
target = c(rep(1, 50), rep(2, 50), rep("Third", 55), rep("Unrelated", 50))
confusion_matrix(prediction, target)
#> Reciprocal RMSE not calculated: more than one non-numeric class.
#> $Accuracy
#> # A tibble: 1 × 5
#> Accuracy `Accuracy LL` `Accuracy UL` `Accuracy Guessing` `Accuracy P-value`
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.927 0.882 0.958 0.268 5.56e-89
#>
#> $Other
#> # A tibble: 5 × 15
#> Class N Sensit…¹ Speci…² PPV/P…³ NPV F1/Di…⁴ Preva…⁵ Detec…⁶ Detec…⁷
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1 50 1 1 1 1 1 0.244 0.244 0.244
#> 2 2 50 0.8 1 1 0.939 0.889 0.244 0.195 0.195
#> 3 Third 55 0.909 0.933 0.833 0.966 0.870 0.268 0.244 0.293
#> 4 Unrelated 50 1 0.968 0.909 1 0.952 0.244 0.244 0.268
#> 5 Average 51.2 0.927 0.975 0.936 0.976 0.928 0.25 0.232 0.25
#> # … with 5 more variables: `Balanced Accuracy` <dbl>, FDR <dbl>, FOR <dbl>,
#> # `FPR/Fallout` <dbl>, FNR <dbl>, and abbreviated variable names
#> # ¹`Sensitivity/Recall/TPR`, ²`Specificity/TNR`, ³`PPV/Precision`,
#> # ⁴`F1/Dice`, ⁵Prevalence, ⁶`Detection Rate`, ⁷`Detection Prevalence`
#>
#> $Table
#> Target
#> Predicted 1 2 Third Unrelated
#> 1 50 0 0 0
#> 2 0 40 0 0
#> Third 0 10 50 0
#> Unrelated 0 0 5 50
#>
#> $recip_rmse
#> [1] NA
#>