From 84db4dad12c60c4ccb760e1d19cf1d002a27b93c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de> Date: Thu, 21 Nov 2024 21:37:50 +0100 Subject: [PATCH] implemented multiclass model. Multiple improvements and bugfixes across pipeline --- R/02_02_functional_traits_preparation.R | 38 ++++- R/04_01_ssdm_modeling.R | 6 +- ..._02_msdm_modeling.R => 04_02_msdm_embed.R} | 16 +-- R/04_03_msdm_embed_informed.R | 135 ++++++++++++++++++ ..._with_traits.R => 04_04_msdm_multiclass.R} | 60 +++----- R/05_performance_analysis.qmd | 93 +++++++----- R/utils.R | 53 +++++++ 7 files changed, 309 insertions(+), 92 deletions(-) rename R/{04_02_msdm_modeling.R => 04_02_msdm_embed.R} (84%) create mode 100644 R/04_03_msdm_embed_informed.R rename R/{04_03_msdm_modeling_with_traits.R => 04_04_msdm_multiclass.R} (53%) diff --git a/R/02_02_functional_traits_preparation.R b/R/02_02_functional_traits_preparation.R index 80f5887..9d66867 100644 --- a/R/02_02_functional_traits_preparation.R +++ b/R/02_02_functional_traits_preparation.R @@ -36,7 +36,41 @@ library("vegan") load("data/r_objects/range_maps.RData") load("data/r_objects/traits_matched.RData") target_species = unique(range_maps$name_matched[!is.na(range_maps$name_matched)]) -traits_target = dplyr::filter(traits_matched, species %in% target_species) + +traits_target = traits_matched %>% + dplyr::filter(species %in% !!target_species) %>% + mutate(match_genus = stringr::str_detect(species, Genus), + match_epithet = stringr::str_detect(species, Species)) %>% + dplyr::group_by(species) %>% + dplyr::group_modify(~ { + if (nrow(.x) == 1) { + return (.x) + } + + x_mod = .x + print(x_mod) + while (nrow(x_mod) > 1){ + if(any(isFALSE(x_mod$match_genus))){ + x_mod = dplyr::filter(x_mod, isTRUE(match_genus)) + next + } + + if(any(isFALSE(x_mod$match_species))){ + x_mod = dplyr::filter(x_mod, isTRUE(match_species)) + next + } + + x_mod = x_mod[1,] + } + + if(nrow(x_mod) == 1){ + return(x_mod) + } else { + return(.x[1,]) + } + + }) + # some pre-processing traits_target$`ForStrat.Value.Proc` = as.numeric(factor(traits_target$`ForStrat.Value`, levels = c("G", "S", "Ar", "A"))) @@ -54,6 +88,6 @@ activity_dist = vegan::vegdist(traits_target[,activity_cols], "bray") bodymass_dist = dist(traits_target[,bodymass_cols]) func_dist = (diet_dist + foraging_dist/max(foraging_dist) + activity_dist + bodymass_dist/max(bodymass_dist)) / 4 -names(func_dist) = stringr::str_replace_all(traits_target$species, pattern = " ", "_") +names(func_dist) = traits_target$species save(func_dist, file = "data/r_objects/func_dist.RData") diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R index bdc7c9f..959f77c 100644 --- a/R/04_01_ssdm_modeling.R +++ b/R/04_01_ssdm_modeling.R @@ -114,7 +114,7 @@ ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(see nn_fit = dnn( X = train_data[, predictors], Y = train_data$present, - hidden = c(500L, 200L, 50L), + hidden = c(200L, 200L, 200L), loss = "binomial", activation = c("sigmoid", "leaky_relu", "leaky_relu"), epochs = 500L, @@ -139,7 +139,7 @@ ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(see performance_summary = tibble( species = data_spec$species[1], obs = nrow(data_spec), - model = c("RF", "GBM", "GLM", "NN"), + model = c("SSDM_RF", "SSDM_GBM", "SSDM_GLM", "SSDM_NN"), auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC, nn_performance$AUC), accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy, nn_performance$Accuracy), kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa, nn_performance$Kappa), @@ -153,4 +153,4 @@ ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(see ssdm_results = bind_rows(ssdm_results) -save(ssdm_results, file = "data/r_objects/ssdm_results.RData") \ No newline at end of file +save(ssdm_results, file = "data/r_objects/ssdm_results.RData") diff --git a/R/04_02_msdm_modeling.R b/R/04_02_msdm_embed.R similarity index 84% rename from R/04_02_msdm_modeling.R rename to R/04_02_msdm_embed.R index a9e5eea..6c825fd 100644 --- a/R/04_02_msdm_modeling.R +++ b/R/04_02_msdm_embed.R @@ -41,18 +41,8 @@ save(msdm_fit, file = "data/r_objects/msdm_fit2.RData") # ----------------------------------------------------------------------# # Evaluate model #### # ----------------------------------------------------------------------# -load("data/r_objects/msdm_fit2.RData") +load("data/r_objects/msdm_fit.RData") -# Overall -preds_train = predict(msdm_fit, newdata = as.matrix(train_data), type = "response") -preds_test = predict(msdm_fit, newdata = as.matrix(test_data), type = "response") - -hist(preds_train) -hist(preds_test) - -eval_overall = evaluate_model(msdm_fit, test_data) - -# Per species data_split = split(model_data, model_data$species) msdm_results = lapply(data_split, function(data_spec){ @@ -68,7 +58,7 @@ msdm_results = lapply(data_split, function(data_spec){ performance_summary = tibble( species = data_spec$species[1], obs = nrow(data_spec), - model = "NN_MSDM", + model = "MSDM_embed", auc = msdm_performance$AUC, accuracy = msdm_performance$Accuracy, kappa = msdm_performance$Kappa, @@ -78,4 +68,4 @@ msdm_results = lapply(data_split, function(data_spec){ ) }) %>% bind_rows() -save(msdm_results, file = "data/r_objects/msdm_results2.RData") +save(msdm_results, file = "data/r_objects/msdm_results.RData") diff --git a/R/04_03_msdm_embed_informed.R b/R/04_03_msdm_embed_informed.R new file mode 100644 index 0000000..718a8e9 --- /dev/null +++ b/R/04_03_msdm_embed_informed.R @@ -0,0 +1,135 @@ +library(dplyr) +library(tidyr) +library(cito) + +source("R/utils.R") + +load("data/r_objects/model_data.RData") +load("data/r_objects/func_dist.RData") + +# ----------------------------------------------------------------------# +# Prepare data #### +# ----------------------------------------------------------------------# +model_species = intersect(model_data$species, names(func_dist)) + +model_data_final = model_data %>% + dplyr::filter(species %in% !!model_species) %>% + # dplyr::mutate_at(vars(starts_with("layer")), scale) %>% # Scaling seems to make things worse often + dplyr::mutate(species_int = as.integer(as.factor(species))) + +train_data = dplyr::filter(model_data_final, train == 1) +test_data = dplyr::filter(model_data_final, train == 0) + +sp_ind = match(model_species, names(func_dist)) +func_dist = as.matrix(func_dist)[sp_ind, sp_ind] + +embeddings = eigen(func_dist)$vectors[,1:20] +predictors = paste0("layer_", 1:19) + +# ----------------------------------------------------------------------# +# Without training the embedding #### +# ----------------------------------------------------------------------# +# 1. Train +formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, weights = embeddings, lambda = 0.00001, train = F)")) +msdm_fit_embedding_untrained = dnn( + formula, + data = train_data, + hidden = c(200L, 200L, 200L), + loss = "binomial", + activation = c("sigmoid", "leaky_relu", "leaky_relu"), + epochs = 12000L, + lr = 0.01, + batchsize = nrow(train_data)/5, + dropout = 0.1, + burnin = 100, + optimizer = config_optimizer("adam", weight_decay = 0.001), + lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7), + early_stopping = 250, + validation = 0.3, + device = "cuda" +) +save(msdm_fit_embedding_untrained, file = "data/r_objects/msdm_fit_embedding_untrained.RData") + +# 2. Evaluate +# Per species +data_split = test_data %>% + group_by(species_int) %>% + group_split() + +msdm_results_embedding_untrained = lapply(data_split, function(data_spec){ + target_species = data_spec$species[1] + data_spec = dplyr::select(data_spec, -species) + + msdm_performance = tryCatch({ + evaluate_model(msdm_fit_embedding_untrained, data_spec) + }, error = function(e){ + list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA) + }) + + performance_summary = tibble( + species = !!target_species, + obs = length(which(model_data$species == target_species)), + model = "MSDM_embed_informed_untrained", + auc = msdm_performance$AUC, + accuracy = msdm_performance$Accuracy, + kappa = msdm_performance$Kappa, + precision = msdm_performance$Precision, + recall = msdm_performance$Recall, + f1 = msdm_performance$F1 + ) +}) %>% bind_rows() + +save(msdm_results_embedding_untrained, file = "data/r_objects/msdm_results_embedding_untrained.RData") + +# -------------------------------------------------------------------# +# With training the embedding #### +# ------------------------------------------------------------ ------# +formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, weights = embeddings, lambda = 0.00001, train = T)")) +msdm_fit_embedding_trained = dnn( + formula, + data = train_data, + hidden = c(200L, 200L, 200L), + loss = "binomial", + activation = c("sigmoid", "leaky_relu", "leaky_relu"), + epochs = 12000L, + lr = 0.01, + batchsize = nrow(train_data)/5, + dropout = 0.1, + burnin = 100, + optimizer = config_optimizer("adam", weight_decay = 0.001), + lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7), + early_stopping = 250, + validation = 0.3, + device = "cuda" +) +save(msdm_fit_embedding_trained, file = "data/r_objects/msdm_fit_embedding_trained.RData") + +# 2. Evaluate +data_split = test_data %>% + group_by(species_int) %>% + group_split() + +msdm_results_embedding_trained = lapply(data_split, function(data_spec){ + target_species = data_spec$species[1] + data_spec = dplyr::select(data_spec, -species) + + msdm_performance = tryCatch({ + evaluate_model(msdm_fit_embedding_trained, data_spec) + }, error = function(e){ + list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA) + }) + + performance_summary = tibble( + species = !!target_species, + obs = length(which(model_data$species == target_species)), + model = "MSDM_embed_informed_trained", + auc = msdm_performance$AUC, + accuracy = msdm_performance$Accuracy, + kappa = msdm_performance$Kappa, + precision = msdm_performance$Precision, + recall = msdm_performance$Recall, + f1 = msdm_performance$F1 + ) +}) %>% bind_rows() + +save(msdm_results_embedding_trained, file = "data/r_objects/msdm_results_embedding_trained.RData") diff --git a/R/04_03_msdm_modeling_with_traits.R b/R/04_04_msdm_multiclass.R similarity index 53% rename from R/04_03_msdm_modeling_with_traits.R rename to R/04_04_msdm_multiclass.R index 3495629..b2af99b 100644 --- a/R/04_03_msdm_modeling_with_traits.R +++ b/R/04_04_msdm_multiclass.R @@ -5,33 +5,29 @@ library(cito) source("R/utils.R") load("data/r_objects/model_data.RData") -load("data/r_objects/func_dist.RData") # ----------------------------------------------------------------------# -# Prepare embeddings #### +# Prepare data #### # ----------------------------------------------------------------------# -model_data$species_int = as.integer(as.factor(model_data$species)) +train_data = dplyr::filter(model_data, present == 1, train == 1) # Use only presences for training +test_data = dplyr::filter(model_data, train == 0) # Evaluate on presences + pseudo-absences for comparability with binary models -train_data = dplyr::filter(model_data, train == 1) -test_data = dplyr::filter(model_data, train == 0) +predictors = paste0("layer_", 1:19) # ----------------------------------------------------------------------# -# Train model #### +# Train model #### # ----------------------------------------------------------------------# -predictors = paste0("layer_", 1:19) -formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, dim = 20, lambda = 0.000001)")) - -msdm_fit = dnn( +formula = as.formula(paste0("species ~ ", paste(predictors, collapse = '+'))) +msdm_fit_multiclass = dnn( formula, data = train_data, - hidden = c(500L, 1000L, 1000L), - loss = "binomial", - activation = c("sigmoid", "leaky_relu", "leaky_relu"), - epochs = 10000L, - lr = 0.01, - baseloss = 1, + hidden = c(200L, 200L, 200L), + loss = "softmax", + activation = c("leaky_relu", "leaky_relu", "leaky_relu"), + epochs = 5000L, + lr = 0.01, batchsize = nrow(train_data)/5, - dropout = 0.25, + dropout = 0.1, burnin = 100, optimizer = config_optimizer("adam", weight_decay = 0.001), lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7), @@ -40,38 +36,26 @@ msdm_fit = dnn( device = "cuda" ) -save(msdm_fit, file = "data/r_objects/msdm_fit.RData") +save(msdm_fit_multiclass, file = "data/r_objects/msdm_fit_multiclass.RData") # ----------------------------------------------------------------------# # Evaluate model #### # ----------------------------------------------------------------------# -load("data/r_objects/msdm_fit.RData") - -# Overall -preds_train = predict(msdm_fit, newdata = as.matrix(train_data), type = "response") -preds_test = predict(msdm_fit, newdata = as.matrix(test_data), type = "response") +data_split = test_data %>% + group_by(species) %>% + group_split() -hist(preds_train) -hist(preds_test) - -eval_overall = evaluate_model(msdm_fit, test_data) - -# Per species -data_split = split(model_data, model_data$species) - -msdm_results = lapply(data_split, function(data_spec){ - test_data = dplyr::filter(data_spec, train == 0) - +msdm_results_multiclass = lapply(data_split, function(data_spec){ msdm_performance = tryCatch({ - evaluate_model(msdm_fit, test_data) + evaluate_multiclass_model(msdm_fit_multiclass, data_spec, k = 10) # Top-k accuracy }, error = function(e){ list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA) }) performance_summary = tibble( species = data_spec$species[1], - obs = nrow(data_spec), - model = "NN_MSDM", + obs = nrow(dplyr::filter(model_data, species == data_spec$species[1])), + model = "MSDM_multiclass", auc = msdm_performance$AUC, accuracy = msdm_performance$Accuracy, kappa = msdm_performance$Kappa, @@ -81,4 +65,4 @@ msdm_results = lapply(data_split, function(data_spec){ ) }) %>% bind_rows() -save(msdm_results, file = "data/r_objects/msdm_results.RData") +save(msdm_results_multiclass, file = "data/r_objects/msdm_results_multiclass.RData") diff --git a/R/05_performance_analysis.qmd b/R/05_performance_analysis.qmd index 7eb33f1..ca81a81 100644 --- a/R/05_performance_analysis.qmd +++ b/R/05_performance_analysis.qmd @@ -7,15 +7,16 @@ engine: knitr ```{r init, echo = FALSE, include = FALSE} library(tidyverse) -library(Symobio) library(sf) library(plotly) library(DT) load("../data/r_objects/ssdm_results.RData") -load("../data/r_objects/msdm_fit2.RData") -load("../data/r_objects/msdm_results2.RData") +load("../data/r_objects/msdm_fit.RData") +load("../data/r_objects/msdm_results.RData") +load("../data/r_objects/msdm_results_embedding_trained.RData.") +load("../data/r_objects/msdm_results_multiclass.RData.") load("../data/r_objects/range_maps.RData") load("../data/r_objects/range_maps_gridded.RData") load("../data/r_objects/occs_final.RData") @@ -25,7 +26,7 @@ sf::sf_use_s2(use_s2 = FALSE) ```{r globals, echo = FALSE, include = FALSE} # Select metrics -focal_metrics = c("auc", "f1", "accuracy") # There's a weird bug in plotly that scrambles up lines when using more then three groups +focal_metrics = c("auc", "f1", "accuracy") # There's a weird bug in plotly that scrambles up lines when using more than three groups # Dropdown options plotly_buttons = list() @@ -52,72 +53,92 @@ lin_regression = function(x, y){ ) } +# Performance table +performance = bind_rows(ssdm_results, msdm_results, msdm_results_embedding_trained, msdm_results_multiclass) %>% + pivot_longer(c(auc, accuracy, kappa, precision, recall, f1), names_to = "metric") %>% + dplyr::filter(!is.na(value)) %>% + dplyr::mutate( + metric = factor(metric, levels = c("auc", "kappa", "f1", "accuracy", "precision", "recall")), + value = round(pmax(value, 0, na.rm = T), 3) # Fix one weird instance of f1 < 0 + ) %>% + dplyr::filter(metric %in% focal_metrics) ``` ## Summary -This document summarizes the performance of four SDM algorithms (Random Forest, Gradient Boosting Machine, Generalized Linear Model, Deep Neural Network) for `r I(length(unique(ssdm_results$species)))` South American mammal species. We use `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) to evaluate model performance and look at how performance varies with five factors (number of records, range size, range coverage, range coverage bias, and functional group). +This document summarizes the performance of different sSDM and mMSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed along five potential influence factors (number of records, range size, range coverage, range coverage bias, and functional group). The comparison of sSDM vs mSDM approaches is of particular interest. -### Modeling decisions: +Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling). -#### Data +### Modeling overview: + +#### General decisions - Randomly sampled pseudo-absences from expanded area of extent of occurrence records (×1.25) - Balanced presences and absences for each species - Predictors: all 19 CHELSA bioclim variables -- 70/30 Split of training vs. test data +- 70/30 Split of training vs. test data (except for NN models) -#### Algorithms +#### sSDM Algorithms -Random Forest +Random Forest (**SSDM_RF**) -- Spatial block cross-validation during training - Hyperparameter tuning of `mtry` +- Spatial block cross-validation during training -Generalized boosted machine +Generalized boosted machine (**SSDM_GBM**) -- Spatial block cross-validation during training - Hyperparameter tuning across `n.trees` , `interaction.depth` , `shrinkage`, `n.minobsinnode` +- Spatial block cross-validation during training -Generalized Linear Model +Generalized Linear Model (**SSDM_GLM**) -- Spatial block cross-validation during training - Logistic model with binomial link function +- Spatial block cross-validation during training -Neural Netwok (single-species) +Neural Netwok (**SSDM_NN**) -- Three hidden layers (sigmoid - 500, leaky_relu - 200, leaky_relu - 50) -- Binomial loss, ADAM optimizer -- Poor convergence +- Three hidden layers, leaky ReLu activations, binomial loss +- no spatial block cross-validation during training -Neural Network (multi-species) +#### mSDM Algorithms -- Three hidden layers (sigmoid - 500, leaky_relu - 1000, leaky_relu - 1000) -- Binomial loss, ADAM optimizer -- very slow convergence (`r I(length(which(!is.na(msdm_fit$losses$train_l))))` epochs) +Binary Neural Network with species embedding (**MSDM_embed**) + +- definition: presence \~ environment + embedding(species) +- prediction: probability of occurrence given a set of (environmental) inputs and species identity +- embedding initialized at random +- three hidden layers, sigmoid + leaky ReLu activations, binomial loss + +Binary Neural Network with trait-informed species embedding (**MSDM_embed_traits**) + +- definition: presence \~ environment + embedding(species) +- prediction: probability of occurrence given a set of (environmental) inputs and species identity +- embedding initialized using eigenvectors of functional distance matrix +- three hidden layers, sigmoid + leaky ReLu activations, binomial loss + +Multi-Class Neural Network (**MSDM_multiclass**) + +- definition: species identity \~ environment +- prediction: probability distribution across all observed species given a set of (environmental) inputs +- presence-only data in training +- three hidden layers, leaky ReLu activations, softmax loss +- Top-k based evaluation (k=10, P/A \~ target species in / not among top 10 predictions) ### Key findings: -- Performance: RF \> GBM \> GLM \> NN -- Convergence problems with Neural Notwork Models +- sSDM algorithms (RF, GBM) outperformed mSDMs in most cases +- mSDMs showed indications of better performance for rare species (\< 10-20 occurrences) - More occurrence records and larger range sizes tended to improve model performance -- Higher range coverage correlated with better performance. +- Higher range coverage correlated with better performance - Range coverage bias and functional group showed some impact but were less consistent +- Convergence problems hampered NN sSDM performance ## Analysis The table below shows the analysed modeling results. ```{r performance, echo = FALSE, message=FALSE, warnings=FALSE} -performance = bind_rows(ssdm_results, msdm_results) %>% - pivot_longer(c(auc, accuracy, kappa, precision, recall, f1), names_to = "metric") %>% - dplyr::filter(!is.na(value)) %>% - dplyr::mutate( - metric = factor(metric, levels = c("auc", "kappa", "f1", "accuracy", "precision", "recall")), - value = round(pmax(value, 0, na.rm = T), 3) # Fix one weird instance of f1 < 0 - ) %>% - dplyr::filter(metric %in% focal_metrics) - DT::datatable(performance) ``` @@ -508,7 +529,7 @@ bslib::card(plot, full_screen = T) Functional groups were assigned based on taxonomic order. The following groupings were used: | Functional group | Taxomic orders | -|------------------------|-----------------------------------------------| +|-----------------------|-----------------------------------------------------------------------| | large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla | | small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha | | arboreal | Primates, Pilosa | diff --git a/R/utils.R b/R/utils.R index edde20d..b1da26a 100644 --- a/R/utils.R +++ b/R/utils.R @@ -67,3 +67,56 @@ evaluate_model <- function(model, test_data) { ) ) } + +evaluate_multiclass_model <- function(model, test_data, k) { + # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances. + # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN) + + # Precision: The proportion of true positives out of all instances predicted as positive. + # Formula: Precision = TP / (TP + FP) + + # Recall (Sensitivity): The proportion of true positives out of all actual positive instances. + # Formula: Recall = TP / (TP + FN) + + # F1 Score: The harmonic mean of Precision and Recall, balancing the two metrics. + # Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall) + target_species = unique(test_data$species) + checkmate::assert_character(target_species, len = 1, any.missing = F) + + new_data = dplyr::select(test_data, -species) + + # Predict probabilities + if(class(model) %in% c("citodnn", "citodnnBootstrap")){ + preds_overall = predict(model, as.matrix(new_data), type = "response") + probs <- as.vector(preds_overall[,target_species]) + + rank = apply(preds_overall, 1, function(x){ # Top-K approach + x_sort = sort(x, decreasing = T) + return(which(names(x_sort) == target_species)) + }) + top_k = as.character(as.numeric(rank <= k)) + preds <- factor(top_k, levels = c("0", "1"), labels = c("A", "P")) + } else { + stop("Unsupported model type: ", class(model)) + } + + actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P")) + + # Calculate AUC + auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc + + # Calculate confusion matrix + cm <- caret::confusionMatrix(preds, actual, positive = "P") + + # Return metrics + return( + list( + AUC = as.numeric(auc), + Accuracy = cm$overall["Accuracy"], + Kappa = cm$overall["Kappa"], + Precision = cm$byClass["Precision"], + Recall = cm$byClass["Recall"], + F1 = cm$byClass["F1"] + ) + ) +} -- GitLab