From 643e472f9256a6ded5e0bd69e8da8a5d06aea9b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Thu, 20 Feb 2025 16:01:03 +0100
Subject: [PATCH] additional model evaluation and analysis

---
 R/04_01_modelling_ssdm.R       |  8 ++--
 R/04_05_modelling_msdm_rf.R    |  2 +-
 R/05_02_publication_analysis.R | 78 +++++++++++++++++++++++++-------
 R/utils.R                      | 83 ++++++++++++----------------------
 4 files changed, 96 insertions(+), 75 deletions(-)

diff --git a/R/04_01_modelling_ssdm.R b/R/04_01_modelling_ssdm.R
index e5240db..d443309 100644
--- a/R/04_01_modelling_ssdm.R
+++ b/R/04_01_modelling_ssdm.R
@@ -87,16 +87,16 @@ for(pa_spec in data_split){
     })
     
     ## Gradient Boosted Machine ####
-    xgb_performance = tryCatch({
-      xgb_fit = train(
+    gbm_performance = tryCatch({
+      gbm_fit = train(
         x = data_train[, predictors],
         y = data_train$present_fct,
-        method = "xgbTree",
+        method = "gbm",
         trControl = train_ctrl,
         tuneLength = 8,
         verbose = F
       )
-      evaluate_model(xgb_fit, data_test)
+      evaluate_model(gbm_fit, data_test)
     }, error = function(e){
       na_performance
     })
diff --git a/R/04_05_modelling_msdm_rf.R b/R/04_05_modelling_msdm_rf.R
index 29819aa..5e13f48 100644
--- a/R/04_05_modelling_msdm_rf.R
+++ b/R/04_05_modelling_msdm_rf.R
@@ -89,7 +89,7 @@ save(rf_fit, file = "data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
 # Evaluate model                                                     ####
 # ----------------------------------------------------------------------#
 msdm_rf_performance = lapply(1:5, function(fold){
-  load(paste0("data/r_objects/msdm_rf/rf_fit_fold", fold, ".RData"))
+  load(paste0("data/r_objects/msdm_rf/msdm_rf_fit_fold", fold, ".RData"))
   
   test_data =dplyr::filter(model_data, fold_eval == fold) %>% 
     sf::st_drop_geometry()
diff --git a/R/05_02_publication_analysis.R b/R/05_02_publication_analysis.R
index ee481d3..f1b9d26 100644
--- a/R/05_02_publication_analysis.R
+++ b/R/05_02_publication_analysis.R
@@ -154,7 +154,7 @@ library(cito)
 library(ranger)
 
 # Define plotting function
-plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm_embed", "msdm_onehot")){
+plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm_embed", "msdm_onehot", "msdm_rf")){
   # Species data
   load("data/r_objects/range_maps.RData")
   
@@ -166,7 +166,7 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
   
   # Extract raster values into df
   bbox_spec = sf::st_bbox(pa_spec) %>% 
-    expand_bbox(expansion = 0.75)
+    expand_bbox(expansion = 0.5)
   
   raster_crop = terra::crop(raster_data, bbox_spec)
   new_data = raster_crop %>% 
@@ -185,6 +185,9 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
       load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
       probabilities = predict(msdm_embed_fit, new_data, type = "response")[,1]
       predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
+    } else if(algorithm == "msdm_rf"){
+      load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
+      predictions = predict(rf_fit, new_data, type = "raw")
     } else if(algorithm == "nn") {
       load(paste0("data/r_objects/ssdm_results/full_models/", spec, "_nn_fit.RData"))
       new_data_tmp = dplyr::select(new_data, -species)
@@ -218,46 +221,89 @@ raster_data = terra::rast(raster_filepaths) %>%
   terra::crop(sf::st_bbox(sa_polygon)) %>% 
   terra::project(sf::st_crs(model_data)$input)
 
-specs = sort(sample(levels(model_data$species), 20))
+specs = sort(sample(levels(model_data$species), 30))
 for(spec in specs){
-  pdf(file = paste0("plots/range_predictions/", spec, ".pdf"))
+  pdf(file = paste0("plots/range_predictions/", spec, " (msdm).pdf"))
   tryCatch({
-    plot_predictions(spec, model_data, raster_data, algorithms = c("gam", "xgb", "rf", "msdm_embed", "msdm_onehot"))
+    plot_predictions(spec, model_data, raster_data, algorithms = c("msdm_embed", "msdm_onehot", "msdm_rf"))
   }, finally = {
     dev.off()
   })
 }
 
 # ------------------------------------------------------------------ #
-# 4. Compare msdm predictions (embed vs. onehot)                  ####
+# 4. Compare predictions across species                           ####
 # ------------------------------------------------------------------ #
 load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
 load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
+load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
 
-plot_embed_vs_onehot = function(){
+compare_species_predictions = function(model, sample_size){
+  specs = sample(unique(model_data$species), 2, replace = F)
+  df1 = spatSample(raster_data, sample_size, replace = F, as.df = T) %>% 
+    drop_na() %>% 
+    dplyr::mutate(species = specs[1])
+  df2= df1 %>% 
+    dplyr::mutate(species = specs[2])
+  
+  p1 = predict_new(model, df1)
+  p2 = predict_new(model, df2)
+  
+  plot(x = p1, y = p2, 
+       xlab = df1$species[1], 
+       ylab = df2$species[1], xlim=c(0,1), ylim = c(0,1), pch = 16, 
+       main = deparse(substitute(model)))
+  abline(a = 0, b = 1)
+  text(x = c(0.1,0.9,0.1,0.9), y = c(0.1,0.1,0.9,0.9), c("AA", "PA", "AP", "PP"), col = "red")
+  mtext(paste0("(R² = ", round(cor(p1,p2) ^ 2, 3), ")"), side = 3, cex = 0.8)
+  abline(h=0.5)
+  abline(v=0.5)
+}
+
+compare_species_predictions(msdm_embed_fit, 500) 
+compare_species_predictions(msdm_onehot_fit, 500) 
+compare_species_predictions(rf_fit, 500) 
+
+## --> Predictions for different species are weakly/moderately correlated in NN models (makes sense)
+## --> Predictions for different species are always highly correlated in RF model (seems problematioc)
+
+# ------------------------------------------------------------------ #
+# 5. Compare predictions across models                            ####
+# ------------------------------------------------------------------ #
+load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
+load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
+load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
+
+compare_model_predictions = function(model1, model2, sample_size){
   spec = sample(unique(model_data$species), 1)
-  new_data = spatSample(raster_data, 100, replace = F, as.df = T) %>% 
+  new_data = spatSample(raster_data, sample_size, replace = F, as.df = T) %>% 
     drop_na() %>% 
     dplyr::mutate(species = spec)
   
-  p1_em = predict(msdm_embed_fit, new_data, type = "response")
-  p1_oh = predict(msdm_onehot_fit, new_data, type = "response")
+  p1 = predict_new(model1, new_data)
+  p2 = predict_new(model2, new_data)
   
   # Compare onehot vs embed
-  plot(x = p1_em, y = p1_oh, xlab = "embed", ylab = "onehot", xlim=c(0,1), ylim = c(0,1), pch = 16, main = spec)
+  plot(x = p1, y = p2, 
+       xlab = deparse(substitute(model1)), 
+       ylab = deparse(substitute(model2)), 
+       xlim=c(0,1), ylim = c(0,1), pch = 16, main = spec)
   abline(a = 0, b = 1)
   text(x = c(0.1,0.9,0.1,0.9), y = c(0.1,0.1,0.9,0.9), c("AA", "PA", "AP", "PP"), col = "red")
+  mtext(paste0("(R² = ", round(cor(p1,p2) ^ 2, 3), ")"), side = 3, cex = 0.8)
   abline(h=0.5)
   abline(v=0.5)
 }
 
-plot_embed_vs_onehot()
+compare_model_predictions(msdm_embed_fit, msdm_onehot_fit, 500) 
+compare_model_predictions(msdm_embed_fit, rf_fit, 500)
+compare_model_predictions(msdm_onehot_fit, rf_fit, 500) 
 
-## --> Predictions are similar between msdms
-## --> Onehot model seems to predict slightly more presences than embedded model (???)
+## --> Predictions for the same species from NN_embed and NN_onehot are moderately/strongly correlated (makes sense)
+## --> Predictions for the same species from NN and RF are weakly/not correlated (seems problematic)
 
 # ------------------------------------------------------------------ #
-# 5. Compare species embeddings across folds                      ####
+# 6. Compare species embeddings across folds                      ####
 # ------------------------------------------------------------------ #
 obs_count = model_data %>% 
   sf::st_drop_geometry() %>% 
@@ -359,7 +405,7 @@ ggplot(data = df_plot, aes(x = obs, y = value)) +
 ## --> Some potential for informed embeddings with respect to rare species?
 
 # ------------------------------------------------------------------ #
-# 6. Analyse embedding of full model                              ####
+# 7. Analyse embedding of full model                              ####
 # ------------------------------------------------------------------ #
 load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
 species_lookup = data.frame(species = levels(model_data$species)) %>% 
diff --git a/R/utils.R b/R/utils.R
index b433a44..1f9ec18 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -25,6 +25,35 @@ expand_bbox <- function(bbox, min_span = 1, expansion = 0.25) {
   return(bbox)
 }
 
+predict_new = function(model, data, type = "prob"){
+  stopifnot(type %in% c("prob", "class"))
+  
+  if(class(model) %in% c("citodnn", "citodnnBootstrap")){
+    probs = predict(model, data, type = "response")[,1]
+    if(type == "prob"){
+      return(probs)
+    } else {
+      preds = factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
+      return(preds)
+    }
+  } else {
+    probs = predict(model, data, type = "prob")$P
+    if(type == "prob"){
+      return(probs)
+    } else {
+      preds = predict(model, data, type = "raw")
+    }
+  }
+}
+
+if(class(model) %in% c("citodnn", "citodnnBootstrap")){
+  p1 = predict(model, df1, type = "response")[,1]
+  p2 = predict(model, df2, type = "response")[,1]
+} else {
+  p1 = predict(model, df1, type = "prob")$P
+  p2 = predict(model, df2, type = "prob")$P
+}
+
 evaluate_model <- function(model, data) {
   # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
   # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
@@ -40,7 +69,6 @@ evaluate_model <- function(model, data) {
   
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    data = dplyr::select(data, any_of(all.vars(model$old_formula)))
     probs = predict(model, data, type = "response")[,1]
     preds = factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
@@ -72,56 +100,3 @@ evaluate_model <- function(model, data) {
     )
   )
 }
-
-evaluate_multiclass_model <- function(model, test_data, k) {
-  # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
-  # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
-  
-  # Precision: The proportion of true positives out of all instances predicted as positive.
-  # Formula: Precision = TP / (TP + FP)
-  
-  # Recall (Sensitivity): The proportion of true positives out of all actual positive instances.
-  # Formula: Recall = TP / (TP + FN)
-  
-  # F1 Score: The harmonic mean of Precision and Recall, balancing the two metrics.
-  # Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall)
-  target_species = unique(test_data$species)
-  checkmate::assert_character(target_species, len = 1, any.missing = F)
-  
-  new_data = dplyr::select(test_data, -species)
-  
-  # Predict probabilities
-  if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    preds_overall = predict(model, as.matrix(new_data), type = "response")
-    probs <- as.vector(preds_overall[,target_species])
-    
-    rank = apply(preds_overall, 1, function(x){         # Top-K approach
-      x_sort = sort(x, decreasing = T)
-      return(which(names(x_sort) == target_species))
-    })
-    top_k = as.character(as.numeric(rank <= k))
-    preds <- factor(top_k, levels = c("0", "1"), labels = c("A", "P"))
-  } else {
-    stop("Unsupported model type: ", class(model))
-  }
-  
-  actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
-  
-  # Calculate AUC
-  auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
-  
-  # Calculate confusion matrix
-  cm <- caret::confusionMatrix(preds, actual, positive = "P")
-  
-  # Return metrics
-  return(
-    list(
-      AUC = as.numeric(auc),
-      Accuracy = cm$overall["Accuracy"],
-      Kappa = cm$overall["Kappa"],
-      Precision = cm$byClass["Precision"],
-      Recall = cm$byClass["Recall"],
-      F1 = cm$byClass["F1"]
-    )
-  )
-}
-- 
GitLab