From 62ff20e6ef7cf27776bfff9cd063566cb03af5a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Fri, 13 Dec 2024 16:16:54 +0100
Subject: [PATCH 1/4] some progress on nested CV

---
 R/03_presence_absence_preparation.R |  15 +-
 R/04_01_ssdm_modeling.R             | 264 +++++++++++++++-------------
 R/04_03_msdm_embed_raw.R            |   2 +-
 R/utils.R                           |  10 +-
 4 files changed, 150 insertions(+), 141 deletions(-)

diff --git a/R/03_presence_absence_preparation.R b/R/03_presence_absence_preparation.R
index 9fd817c..78d9327 100644
--- a/R/03_presence_absence_preparation.R
+++ b/R/03_presence_absence_preparation.R
@@ -126,14 +126,9 @@ model_data = furrr::future_map(occs_split, .progress = TRUE, .options = furrr::f
   pa_spec = occs_spec %>% 
     dplyr::mutate(present = 1) %>% 
     bind_rows(abs_spec) 
-  
-  # Split into train and test datasets
-  train_index = createDataPartition(pa_spec$present, p = 0.7, list = FALSE)
-  pa_spec$train = 0
-  pa_spec$train[train_index] = 1
-  
+
   # Define cross-validation folds
-  folds = tryCatch({
+  folds_eval = tryCatch({
     spatial_folds = suppressMessages(
       blockCV::cv_spatial(
         pa_spec,
@@ -150,11 +145,11 @@ model_data = furrr::future_map(occs_split, .progress = TRUE, .options = furrr::f
     NA
   })
   
-  pa_spec$folds = folds
-  pa_spec$geometry = NULL
+
+  pa_spec$fold_eval = folds_eval
   
   return(pa_spec)
 })
 
 model_data = bind_rows(model_data)
-save(model_data, file = "data/r_objects/model_data.RData")
+save(model_data, file = "data/r_objects/nest_cv/model_data.RData")
diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
index 959f77c..ddef696 100644
--- a/R/04_01_ssdm_modeling.R
+++ b/R/04_01_ssdm_modeling.R
@@ -7,147 +7,161 @@ library(pROC)
 
 source("R/utils.R")
 
-load("data/r_objects/model_data.RData")
+load("data/r_objects/full_evaluation/model_data.RData")
 
 data_split = split(model_data, model_data$species)
 
+# Define empty result for performance eval
+performance = list(    
+  AUC = numeric(0),
+  Accuracy = numeric(0),
+  Kappa = numeric(0),
+  Precision = numeric(0),
+  Recall = numeric(0),
+  F1 = numeric(0)
+)
+
 # ----------------------------------------------------------------------#
 # Train models                                                       ####
 # ----------------------------------------------------------------------#
 future::plan("multisession", workers = 8)
 ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(seed = 123), .f = function(data_spec){
-  # Initial check
-  if(nrow(data_spec) < 10 | anyNA(data_spec$folds)){ 
-    return(NULL)
-  }
-  
   data_spec$present_fct = factor(data_spec$present, levels = c("0", "1"), labels = c("A", "P"))
-  train_data = dplyr::filter(data_spec, train == 1)
-  test_data = dplyr::filter(data_spec, train == 0)
-  
-  # Define empty result for performance eval
-  na_performance = list(    
-    AUC = NA,
-    Accuracy = NA,
-    Kappa = NA,
-    Precision = NA,
-    Recall = NA,
-    F1 = NA
-  )
-  
-  # Define predictors
-  predictors = paste0("layer_", 1:19)
-  
-  # Define caret training routine #####
-  index_train = lapply(unique(sort(train_data$fold)), function(x){
-    return(which(train_data$fold != x))
-  })
   
-  train_ctrl = trainControl(
-    search = "grid",
-    classProbs = TRUE, 
-    index = index_train,
-    summaryFunction = twoClassSummary, 
-    savePredictions = "final"
-  )
-  
-  # Random Forest #####
-  rf_performance = tryCatch({
-    rf_grid = expand.grid(
-      mtry = c(3,7,11,15,19)                # Number of randomly selected predictors
-    )
+  # Outer CV loop
+  for(k in sort(unique(data_spec$fold_eval))){
+    data_test = dplyr::filter(data_spec, fold_eval == k)
+    data_train = dplyr::filter(data_spec, fold_eval != k)
     
-    rf_fit = caret::train(
-      x = train_data[, predictors],
-      y = train_data$present_fct,
-      method = "rf",
-      metric = "ROC",
-      tuneGrid = rf_grid,
-      trControl = train_ctrl
-    )
-    evaluate_model(rf_fit, test_data)
-  }, error = function(e){
-    na_performance
-  })
-  
-  # Gradient Boosted Machine ####
-  gbm_performance = tryCatch({
-    gbm_grid <- expand.grid(
-      n.trees = c(100, 500, 1000, 1500),       # number of trees
-      interaction.depth = c(3, 5, 7),          # Maximum depth of each tree
-      shrinkage = c(0.01, 0.005, 0.001),       # Lower learning rates
-      n.minobsinnode = c(10, 20)               # Minimum number of observations in nodes
-    )
+    if(nrow(data_test) == 0 || nrow(data_train) == 0){
+      return()
+    }
     
-    gbm_fit = train(
-      x = train_data[, predictors],
-      y = train_data$present_fct,
-      method = "gbm",
-      metric = "ROC",
-      verbose = F,
-      tuneGrid = gbm_grid,
-      trControl = train_ctrl
+    # Create inner CV folds for model training
+    cv_train = blockCV::cv_spatial(
+      data_train,
+      column = "present",
+      k = 5,
+      progress = F, plot = F, report = F
     )
-    evaluate_model(gbm_fit, test_data)
-  }, error = function(e){
-    na_performance
-  })
-  
-  # Generalized additive Model ####
-  glm_performance = tryCatch({
-    glm_fit = train(
-      x = train_data[, predictors],
-      y = train_data$present_fct,
-      method = "glm",
-      family=binomial, 
-      metric = "ROC",
-      preProcess = c("center", "scale"),
-      trControl = train_ctrl
+    data_train$fold_train = cv_train$folds_ids
+    
+    # Define caret training routine #####
+    index_train = lapply(unique(sort(data_train$fold_train)), function(x){
+      return(which(data_train$fold_train != x))
+    })
+    
+    train_ctrl = trainControl(
+      search = "grid",
+      classProbs = TRUE, 
+      index = index_train,
+      summaryFunction = twoClassSummary, 
+      savePredictions = "final"
     )
-    evaluate_model(glm_fit, test_data)
-  }, error = function(e){
-    na_performance
-  })
   
-  # Neural Network ####
-  nn_performance = tryCatch({
-    nn_fit = dnn(
-      X = train_data[, predictors],
-      Y = train_data$present,
-      hidden = c(200L, 200L, 200L),
-      loss = "binomial",
-      activation = c("sigmoid", "leaky_relu", "leaky_relu"),
-      epochs = 500L, 
-      lr = 0.02,   
-      baseloss=10,
-      batchsize=nrow(train_data)/4,
-      dropout = 0.1,  # Regularization 
-      optimizer = config_optimizer("adam", weight_decay = 0.001),
-      lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
-      early_stopping = 200, # stop training when validation loss does not decrease anymore
-      validation = 0.3, # used for early stopping and lr_scheduler 
-      device = "cuda",
-      bootstrap = 5
-    )
+    # Define predictors
+    predictors = paste0("layer_", 1:19)
     
-    evaluate_model(nn_fit, test_data)
-  }, error = function(e){
-    na_performance
-  })
-  
-  # Summarize results
-  performance_summary = tibble(
-    species = data_spec$species[1],
-    obs = nrow(data_spec),
-    model = c("SSDM_RF", "SSDM_GBM", "SSDM_GLM", "SSDM_NN"),
-    auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC, nn_performance$AUC),
-    accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy, nn_performance$Accuracy),
-    kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa, nn_performance$Kappa),
-    precision = c(rf_performance$Precision, gbm_performance$Precision, glm_performance$Precision, nn_performance$Precision),
-    recall = c(rf_performance$Recall, gbm_performance$Recall, glm_performance$Recall, nn_performance$Recall),
-    f1 = c(rf_performance$F1, gbm_performance$F1, glm_performance$F1, nn_performance$F1)
-  )
-  
+    # Random Forest #####
+    rf_performance = tryCatch({
+      rf_grid = expand.grid(
+        mtry = c(3,7,11,15,19)                # Number of randomly selected predictors
+      )
+      
+      rf_fit = caret::train(
+        x = data_train[, predictors],
+        y = data_train$present_fct,
+        method = "rf",
+        metric = "ROC",
+        tuneGrid = rf_grid,
+        trControl = train_ctrl
+      )
+      evaluate_model(rf_fit, data_train)
+    }, error = function(e){
+      na_performance
+    })
+    
+    # Gradient Boosted Machine ####
+    gbm_performance = tryCatch({
+      gbm_grid <- expand.grid(
+        n.trees = c(100, 500, 1000, 1500),       # number of trees
+        interaction.depth = c(3, 5, 7),          # Maximum depth of each tree
+        shrinkage = c(0.01, 0.005, 0.001),       # Lower learning rates
+        n.minobsinnode = c(10, 20)               # Minimum number of observations in nodes
+      )
+      
+      gbm_fit = train(
+        x = data_train[, predictors],
+        y = data_train$present_fct,
+        method = "gbm",
+        metric = "ROC",
+        verbose = F,
+        tuneGrid = gbm_grid,
+        trControl = train_ctrl
+      )
+      evaluate_model(gbm_fit, data_test)
+    }, error = function(e){
+      na_performance
+    })
+    
+    # Generalized additive Model ####
+    glm_performance = tryCatch({
+      glm_fit = train(
+        x = data_train[, predictors],
+        y = data_train$present_fct,
+        method = "glm",
+        family=binomial, 
+        metric = "ROC",
+        preProcess = c("center", "scale"),
+        trControl = train_ctrl
+      )
+      evaluate_model(glm_fit, data_test)
+    }, error = function(e){
+      na_performance
+    })
+    
+    # Neural Network ####
+    nn_performance = tryCatch({
+      predictors = paste0("layer_", 1:19)
+      formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+')))
+      
+      nn_fit = dnn(
+        formula,
+        data = data_train,
+        hidden = c(500L, 500L, 500L),
+        loss = "binomial",
+        activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+        epochs = 500L, 
+        lr = 0.001,   
+        baseloss = 1,
+        batchsize = nrow(data_train),
+        dropout = 0.1,
+        burnin = 100,
+        optimizer = config_optimizer("adam", weight_decay = 0.001),
+        lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+        early_stopping = 250,
+        validation = 0.3,
+        device = "cuda"
+      )
+      
+      evaluate_model(nn_fit, data_train)
+    }, error = function(e){
+      na_performance
+    })
+    
+    # Summarize results
+    performance_summary = tibble(
+      species = data_train$species[1],
+      obs = nrow(data_train),
+      model = c("SSDM_RF", "SSDM_GBM", "SSDM_GLM", "SSDM_NN"),
+      auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC, nn_performance$AUC),
+      accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy, nn_performance$Accuracy),
+      kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa, nn_performance$Kappa),
+      precision = c(rf_performance$Precision, gbm_performance$Precision, glm_performance$Precision, nn_performance$Precision),
+      recall = c(rf_performance$Recall, gbm_performance$Recall, glm_performance$Recall, nn_performance$Recall),
+      f1 = c(rf_performance$F1, gbm_performance$F1, glm_performance$F1, nn_performance$F1)
+    )
+  }
   return(performance_summary)
 })
 
diff --git a/R/04_03_msdm_embed_raw.R b/R/04_03_msdm_embed_raw.R
index 6e7924a..03a8737 100644
--- a/R/04_03_msdm_embed_raw.R
+++ b/R/04_03_msdm_embed_raw.R
@@ -34,7 +34,7 @@ msdm_fit_embedding_raw = dnn(
   lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
   early_stopping = 250,
   validation = 0.3,
-  device = "cuda",
+  device = "cuda"
 )
 
 save(msdm_fit_embedding_raw, file = "data/r_objects/msdm_fit_embedding_raw.RData")
diff --git a/R/utils.R b/R/utils.R
index b1da26a..1c9dc02 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -25,7 +25,7 @@ expand_bbox <- function(bbox, min_span = 1, expansion = 0.25) {
   return(bbox)
 }
 
-evaluate_model <- function(model, test_data) {
+evaluate_model <- function(model, data) {
   # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
   # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
   
@@ -40,14 +40,14 @@ evaluate_model <- function(model, test_data) {
   
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    probs <- predict(model, as.matrix(test_data), type = "response")[,1]
+    probs <- predict(model, as.matrix(data), type = "response")[,1]
     preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
-    probs <- predict(model, test_data, type = "prob")$P
-    preds <- predict(model, test_data, type = "raw")
+    probs <- predict(model, data, type = "prob")$P
+    preds <- predict(model, data, type = "raw")
   }
   
-  actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
+  actual <- factor(data$present, levels = c("0", "1"), labels = c("A", "P"))
   
   # Calculate AUC
   auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
-- 
GitLab


From 7868a722e05726595ae2834b13fd811c1f72db35 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Tue, 17 Dec 2024 13:01:52 +0100
Subject: [PATCH 2/4] finalize nested cv

---
 R/04_01_ssdm_modeling.R     | 336 +++++++++++++++++++-----------------
 R/05_02_MSDM_comparison.qmd |  28 +--
 R/utils.R                   |   2 +-
 3 files changed, 197 insertions(+), 169 deletions(-)

diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
index ddef696..efc6107 100644
--- a/R/04_01_ssdm_modeling.R
+++ b/R/04_01_ssdm_modeling.R
@@ -1,170 +1,194 @@
-library(dplyr)
-library(tidyr)
 library(furrr)
-library(caret)
-library(cito)
-library(pROC)
+library(progressr)
+library(dplyr)
 
 source("R/utils.R")
 
-load("data/r_objects/full_evaluation/model_data.RData")
-
-data_split = split(model_data, model_data$species)
-
-# Define empty result for performance eval
-performance = list(    
-  AUC = numeric(0),
-  Accuracy = numeric(0),
-  Kappa = numeric(0),
-  Precision = numeric(0),
-  Recall = numeric(0),
-  F1 = numeric(0)
-)
-
 # ----------------------------------------------------------------------#
-# Train models                                                       ####
+# Define Training function                                           ####
 # ----------------------------------------------------------------------#
-future::plan("multisession", workers = 8)
-ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(seed = 123), .f = function(data_spec){
-  data_spec$present_fct = factor(data_spec$present, levels = c("0", "1"), labels = c("A", "P"))
+train_models = function(data_split){
+  future::plan("multisession", workers = 16)
+  p = progressor(along = data_split)
   
-  # Outer CV loop
-  for(k in sort(unique(data_spec$fold_eval))){
-    data_test = dplyr::filter(data_spec, fold_eval == k)
-    data_train = dplyr::filter(data_spec, fold_eval != k)
-    
-    if(nrow(data_test) == 0 || nrow(data_train) == 0){
-      return()
-    }
-    
-    # Create inner CV folds for model training
-    cv_train = blockCV::cv_spatial(
-      data_train,
-      column = "present",
-      k = 5,
-      progress = F, plot = F, report = F
-    )
-    data_train$fold_train = cv_train$folds_ids
-    
-    # Define caret training routine #####
-    index_train = lapply(unique(sort(data_train$fold_train)), function(x){
-      return(which(data_train$fold_train != x))
-    })
-    
-    train_ctrl = trainControl(
-      search = "grid",
-      classProbs = TRUE, 
-      index = index_train,
-      summaryFunction = twoClassSummary, 
-      savePredictions = "final"
-    )
-  
-    # Define predictors
-    predictors = paste0("layer_", 1:19)
-    
-    # Random Forest #####
-    rf_performance = tryCatch({
-      rf_grid = expand.grid(
-        mtry = c(3,7,11,15,19)                # Number of randomly selected predictors
-      )
+  ssdm_results_list = furrr::future_map(
+    .x = data_split, 
+    .options = furrr::furrr_options(
+      seed = 123,
+      packages = c("dplyr", "tidyr", "sf", "caret", "cito", "pROC")
+    ),
+    .f = function(data_spec){
+      spec = data_spec$species[1]
+      p(sprintf("spec=%s", spec))
       
-      rf_fit = caret::train(
-        x = data_train[, predictors],
-        y = data_train$present_fct,
-        method = "rf",
-        metric = "ROC",
-        tuneGrid = rf_grid,
-        trControl = train_ctrl
-      )
-      evaluate_model(rf_fit, data_train)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Gradient Boosted Machine ####
-    gbm_performance = tryCatch({
-      gbm_grid <- expand.grid(
-        n.trees = c(100, 500, 1000, 1500),       # number of trees
-        interaction.depth = c(3, 5, 7),          # Maximum depth of each tree
-        shrinkage = c(0.01, 0.005, 0.001),       # Lower learning rates
-        n.minobsinnode = c(10, 20)               # Minimum number of observations in nodes
-      )
+      if(all(is.na(data_spec$fold_eval))){
+        warning("Too few samples")
+        return()
+      }
       
-      gbm_fit = train(
-        x = data_train[, predictors],
-        y = data_train$present_fct,
-        method = "gbm",
-        metric = "ROC",
-        verbose = F,
-        tuneGrid = gbm_grid,
-        trControl = train_ctrl
-      )
-      evaluate_model(gbm_fit, data_test)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Generalized additive Model ####
-    glm_performance = tryCatch({
-      glm_fit = train(
-        x = data_train[, predictors],
-        y = data_train$present_fct,
-        method = "glm",
-        family=binomial, 
-        metric = "ROC",
-        preProcess = c("center", "scale"),
-        trControl = train_ctrl
+      # Define empty result for performance eval
+      na_performance = list(    
+        AUC = NA,
+        Accuracy = NA,
+        Kappa = NA,
+        Precision = NA,
+        Recall = NA,
+        F1 = NA
       )
-      evaluate_model(glm_fit, data_test)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Neural Network ####
-    nn_performance = tryCatch({
-      predictors = paste0("layer_", 1:19)
-      formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+')))
       
-      nn_fit = dnn(
-        formula,
-        data = data_train,
-        hidden = c(500L, 500L, 500L),
-        loss = "binomial",
-        activation = c("sigmoid", "leaky_relu", "leaky_relu"),
-        epochs = 500L, 
-        lr = 0.001,   
-        baseloss = 1,
-        batchsize = nrow(data_train),
-        dropout = 0.1,
-        burnin = 100,
-        optimizer = config_optimizer("adam", weight_decay = 0.001),
-        lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
-        early_stopping = 250,
-        validation = 0.3,
-        device = "cuda"
-      )
+      # Create factor presence column 
+      data_spec$present_fct = factor(data_spec$present, levels = c("0", "1"), labels = c("A", "P"))
       
-      evaluate_model(nn_fit, data_train)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Summarize results
-    performance_summary = tibble(
-      species = data_train$species[1],
-      obs = nrow(data_train),
-      model = c("SSDM_RF", "SSDM_GBM", "SSDM_GLM", "SSDM_NN"),
-      auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC, nn_performance$AUC),
-      accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy, nn_performance$Accuracy),
-      kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa, nn_performance$Kappa),
-      precision = c(rf_performance$Precision, gbm_performance$Precision, glm_performance$Precision, nn_performance$Precision),
-      recall = c(rf_performance$Recall, gbm_performance$Recall, glm_performance$Recall, nn_performance$Recall),
-      f1 = c(rf_performance$F1, gbm_performance$F1, glm_performance$F1, nn_performance$F1)
-    )
-  }
-  return(performance_summary)
-})
+      # Outer CV loop (for averaging performance metrics)
+      performance_cv = lapply(sort(unique(data_spec$fold_eval)), function(k){
+        data_test = dplyr::filter(data_spec, fold_eval == k)
+        data_train = dplyr::filter(data_spec, fold_eval != k)
+        
+        if(nrow(data_test) == 0 || nrow(data_train) == 0){
+          warning("Too few samples")
+        }
+        
+        # Create inner CV folds for model training
+        cv_train = blockCV::cv_spatial(
+          data_train,
+          column = "present",
+          k = 5,
+          progress = F, plot = F, report = F
+        )
+        data_train$fold_train = cv_train$folds_ids
+        
+        # Drop geometry
+        data_train$geometry = NULL
+        data_test$geometry = NULL
+        
+        # Define caret training routine #####
+        index_train = lapply(unique(sort(data_train$fold_train)), function(x){
+          return(which(data_train$fold_train != x))
+        })
+        
+        train_ctrl = trainControl(
+          search = "random",
+          classProbs = TRUE, 
+          index = index_train,
+          summaryFunction = twoClassSummary, 
+          savePredictions = "final",
+        )
+        
+        # Define predictors
+        predictors = paste0("layer_", 1:19)
+        
+        # Random Forest #####
+        rf_performance = tryCatch({
+          # Fit model
+          rf_fit = caret::train(
+            x = data_train[, predictors],
+            y = data_train$present_fct,
+            method = "rf",
+            metric = "AUC",
+            trControl = train_ctrl,
+            tuneLength = 4,
+            verbose = F
+          )
+          
+          evaluate_model(rf_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Gradient Boosted Machine ####
+        gbm_performance = tryCatch({
+          gbm_fit = train(
+            x = data_train[, predictors],
+            y = data_train$present_fct,
+            method = "gbm",
+            metric = "AUC",
+            trControl = train_ctrl,
+            tuneLength = 4,
+            verbose = F
+          )
+          evaluate_model(gbm_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Generalized Additive Model ####
+        gam_performance = tryCatch({
+          gam_fit = train(
+            x = data_train[, predictors],
+            y = data_train$present_fct,
+            method = "gamSpline",
+            metric = "AUC",
+            tuneLength = 4,
+            trControl = train_ctrl
+          )
+          evaluate_model(gam_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Neural Network ####
+        nn_performance = tryCatch({
+          formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+')))
+          
+          nn_fit = dnn(
+            formula,
+            data = data_train,
+            hidden = c(100L, 100L, 100L),
+            loss = "binomial",
+            activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
+            epochs = 400L, 
+            burnin = 100L,
+            lr = 0.001,   
+            batchsize = max(nrow(data_train)/10, 64),
+            dropout = 0.1,
+            optimizer = config_optimizer("adam", weight_decay = 0.001),
+            lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 50, factor = 0.7),
+            early_stopping = 100,
+            validation = 0.2,
+            device = "cuda",
+            verbose = F,
+            plot = F
+          )
+          
+          evaluate_model(nn_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Summarize results
+        performance_summary = tibble(
+          species = spec,
+          obs = nrow(data_train),
+          fold_eval = k,
+          model = c("RF", "GBM", "GAM", "NN"),
+          auc = c(rf_performance$AUC, gbm_performance$AUC, gam_performance$AUC, nn_performance$AUC),
+          accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, gam_performance$Accuracy, nn_performance$Accuracy),
+          kappa = c(rf_performance$Kappa, gbm_performance$Kappa, gam_performance$Kappa, nn_performance$Kappa),
+          precision = c(rf_performance$Precision, gbm_performance$Precision, gam_performance$Precision, nn_performance$Precision),
+          recall = c(rf_performance$Recall, gbm_performance$Recall, gam_performance$Recall, nn_performance$Recall),
+          f1 = c(rf_performance$F1, gbm_performance$F1, gam_performance$F1, nn_performance$F1)
+        ) %>% 
+          pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value")
+        
+        return(performance_summary)
+      })
+      
+      # Combine and save evaluation results
+      performance_spec = bind_rows(performance_cv)
+      filename = paste0(gsub(" ", "_", spec), "_results.RData")
+      save(performance_spec, file = file.path("data/r_objects/nested_cv/model_results/", filename))
+    }
+  )
+}
+
+# ----------------------------------------------------------------------#
+# Run training                                                       ####
+# ----------------------------------------------------------------------#
+load("data/r_objects/nested_cv/model_data.RData")
+
+handlers(global = TRUE)
+handlers("progress")
 
-ssdm_results = bind_rows(ssdm_results)
+data_split = group_split(model_data, species)
 
-save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
+train_models(data_split)
\ No newline at end of file
diff --git a/R/05_02_MSDM_comparison.qmd b/R/05_02_MSDM_comparison.qmd
index 5b14c8a..b70af79 100644
--- a/R/05_02_MSDM_comparison.qmd
+++ b/R/05_02_MSDM_comparison.qmd
@@ -12,13 +12,15 @@ library(plotly)
 library(DT)
 library(shiny)
 
-load("../data/r_objects/msdm_results_embedding_raw.RData")
-load("../data/r_objects/msdm_results_embedding_traits_static.RData")
-load("../data/r_objects/msdm_results_embedding_traits_trained.RData")
-load("../data/r_objects/msdm_results_embedding_phylo_static.RData")
-load("../data/r_objects/msdm_results_embedding_phylo_trained.RData")
-load("../data/r_objects/msdm_results_embedding_range_static.RData")
-load("../data/r_objects/msdm_results_embedding_range_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_raw.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_traits_static.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_traits_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_phylo_static.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_phylo_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_range_static.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_range_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_multi_nolonlat.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_multi_lonlat.RData")
 
 sf::sf_use_s2(use_s2 = FALSE)
 ```
@@ -40,7 +42,9 @@ results_embedding_informed = c(
   "msdm_results_embedding_range_static",
   "msdm_results_embedding_range_trained",
   "msdm_results_embedding_traits_static",
-  "msdm_results_embedding_traits_trained"
+  "msdm_results_embedding_traits_trained",
+  "msdm_results_embedding_multi_nolonlat",
+  "msdm_results_embedding_multi_lonlat"
 )
 
 results_embedding_informed_merged = lapply(results_embedding_informed, function(df_name){
@@ -63,7 +67,9 @@ results_final = results_embedding_raw %>%
       "MSDM_embed_informed_traits_static" = "M_TS",
       "MSDM_embed_informed_traits_trained" = "M_TT",
       "MSDM_embed_informed_range_static" = "M_RS",
-      "MSDM_embed_informed_range_trained" = "M_RT"
+      "MSDM_embed_informed_range_trained" = "M_RT",
+      "MSDM_embed_informed_multi_nolonlat" = "M_MT_nolonlat",
+      "MSDM_embed_informed_multi_lonlat" = "M_MT_lonlat"
     ),
     across(all_of(focal_metrics), round, 3)
   )
@@ -274,13 +280,12 @@ for(model_name in unique(df_plot$model)){
 bslib::card(plot, full_screen = T)
 ```
 
-
 ## *Relative Performance*
 
 ```{r delta, echo = FALSE, message=FALSE, warnings=FALSE}
 results_ranked_obs = results_final_long %>% 
   group_by(species,  metric) %>% 
-  mutate(rank = rev(rank(value)))
+  mutate(rank = rank(value))
 
 reglines = results_ranked_obs %>%
   group_by(model, metric) %>%
@@ -426,7 +431,6 @@ bslib::card(plot, full_screen = T)
 ```
 :::
 
-
 ## *Trait space*
 
 ```{r trait_pca, echo = FALSE, message=FALSE, warnings=FALSE}
diff --git a/R/utils.R b/R/utils.R
index 1c9dc02..a7436a5 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -40,7 +40,7 @@ evaluate_model <- function(model, data) {
   
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    probs <- predict(model, as.matrix(data), type = "response")[,1]
+    probs <- predict(model, data, type = "response")[,1]
     preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
     probs <- predict(model, data, type = "prob")$P
-- 
GitLab


From fae69c9471f808f2018c2644840c5a36ae101d3a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Thu, 9 Jan 2025 22:33:02 +0100
Subject: [PATCH 3/4] work on model report

---
 .gitignore                       |   1 +
 R/04_01_ssdm_modeling.R          |  13 +-
 R/05_01_performance_analysis.qmd | 367 +++++++++++++++++--------------
 3 files changed, 212 insertions(+), 169 deletions(-)

diff --git a/.gitignore b/.gitignore
index ac6cec0..3acf079 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,5 +11,6 @@ renv/cache/
 
 # Data files
 data/
+plots/
 R/*/
 R/*.html
\ No newline at end of file
diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
index efc6107..5a0de5e 100644
--- a/R/04_01_ssdm_modeling.R
+++ b/R/04_01_ssdm_modeling.R
@@ -191,4 +191,15 @@ handlers("progress")
 
 data_split = group_split(model_data, species)
 
-train_models(data_split)
\ No newline at end of file
+train_models(data_split)
+
+# ----------------------------------------------------------------------#
+# Combine results                                                    ####
+# ----------------------------------------------------------------------#
+files = list.files("data/r_objects/nested_cv/model_results/", full.names = T)
+ssdm_results = lapply(files, function(f){load(f); return(performance_spec)}) %>% 
+  bind_rows() %>% 
+  dplyr::group_by(species, model, metric) %>% 
+  dplyr::summarize(value = mean(value, na.rm =T))
+
+save(ssdm_results, file = "data/r_objects/nested_cv/ssdm_results.RData")
diff --git a/R/05_01_performance_analysis.qmd b/R/05_01_performance_analysis.qmd
index e2c2f33..44b9b8e 100644
--- a/R/05_01_performance_analysis.qmd
+++ b/R/05_01_performance_analysis.qmd
@@ -11,12 +11,12 @@ library(sf)
 library(plotly)
 library(DT)
 
-
-load("../data/r_objects/ssdm_results.RData")
-load("../data/r_objects/msdm_fit.RData")
-load("../data/r_objects/msdm_results.RData")
-load("../data/r_objects/msdm_results_embedding_trained.RData")
-load("../data/r_objects/msdm_results_multiclass.RData.")
+load("../data/r_objects/nested_cv/model_data.RData")
+load("../data/r_objects/nested_cv/ssdm_results.RData")
+#load("../data/r_objects/msdm_fit.RData")
+#load("../data/r_objects/msdm_results.RData")
+#load("../data/r_objects/msdm_results_embedding_trained.RData")
+#load("../data/r_objects/msdm_results_multiclass.RData.")
 load("../data/r_objects/range_maps.RData")
 load("../data/r_objects/range_maps_gridded.RData")
 load("../data/r_objects/occs_final.RData")
@@ -25,14 +25,13 @@ sf::sf_use_s2(use_s2 = FALSE)
 ```
 
 ```{r globals, echo = FALSE, include = FALSE}
-# Select metrics
-focal_metrics = c("auc", "f1", "accuracy")  # There's a weird bug in plotly that scrambles up lines when using more than three groups
+# Count occs per species
+obs_count = model_data %>% 
+  sf::st_drop_geometry() %>% 
+  dplyr::filter(present == 1) %>% 
+  dplyr::group_by(species) %>% 
+  dplyr::summarise(obs = n())
 
-# Dropdown options
-plotly_buttons = list()
-for(metric in focal_metrics){
-  plotly_buttons[[length(plotly_buttons) + 1]] = list(method = "restyle", args = list("transforms[0].value", metric), label = metric)
-}
 
 # Regression functions
 asym_regression = function(x, y){
@@ -44,8 +43,8 @@ asym_regression = function(x, y){
   )
 }
 
-lin_regression = function(x, y){
-  glm_fit = suppressWarnings(glm(y~x, family = "binomial"))
+lin_regression = function(x, y, family = "binomial"){
+  glm_fit = suppressWarnings(glm(y~x, family = family))
   new_x = seq(min(x), max(x), length.out = 100)
   data.frame(
     x = new_x,
@@ -54,187 +53,220 @@ lin_regression = function(x, y){
 }
 
 # Performance table
-performance = bind_rows(ssdm_results, msdm_results, msdm_results_embedding_trained, msdm_results_multiclass) %>% 
-  pivot_longer(c(auc, accuracy, kappa, precision, recall, f1), names_to = "metric") %>% 
-  dplyr::filter(!is.na(value)) %>% 
+performance = bind_rows(ssdm_results) %>%
+  ungroup() %>% 
   dplyr::mutate(
-    metric = factor(metric, levels = c("auc", "kappa", "f1", "accuracy", "precision", "recall")),
-    value = round(pmax(value, 0, na.rm = T), 3) # Fix one weird instance of f1 < 0
-  ) %>% 
-  dplyr::filter(metric %in% focal_metrics)
+    value = case_when(
+      ((is.na(value) | is.nan(value)) & metric %in% c("auc", "f1", "accurracy", "precision", "recall")) ~ 0.5,
+      ((is.na(value) | is.nan(value)) & metric %in% c("kappa")) ~ 0,
+      .default = value
+    )
+  )
+
+focal_metrics = unique(performance$metric)
 ```
 
-## Summary
+## *Summary*
 
-This document summarizes the performance of different sSDM and mSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed along five potential influence factors (number of records, range size, range coverage, range coverage bias, and functional group). The comparison of sSDM vs mSDM approaches is of particular interest.
+*This document summarizes the performance of different sSDM and mSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed along five potential influence factors (number of records, range size, range coverage, range coverage bias, and functional group). The comparison of sSDM vs mSDM approaches is of particular interest.*
 
-Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).
+*Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).*
 
-### Modeling overview:
+### *Modeling overview:*
 
-#### General decisions
+#### *General decisions*
 
--   Randomly sampled pseudo-absences from expanded area of extent of occurrence records (×1.25)
--   Balanced presences and absences for each species
--   Predictors: all 19 CHELSA bioclim variables
--   70/30 Split of training vs. test data (except for NN models)
+-   *Randomly sampled pseudo-absences from expanded area of extent of occurrence records (×1.25)*
+-   *Balanced presences and absences for each species*
+-   *Predictors: all 19 CHELSA bioclim variables*
+-   *70/30 Split of training vs. test data (except for NN models)*
 
-#### sSDM Algorithms
+#### *sSDM Algorithms*
 
-Random Forest (**SSDM_RF**)
+*Random Forest (**SSDM_RF**)*
 
--   Hyperparameter tuning of `mtry`
--   Spatial block cross-validation during training
+-   *Hyperparameter tuning of `mtry`*
+-   *Spatial block cross-validation during training*
 
-Generalized boosted machine (**SSDM_GBM**)
+*Generalized boosted machine (**SSDM_GBM**)*
 
--   Hyperparameter tuning across `n.trees` , `interaction.depth` , `shrinkage`, `n.minobsinnode`
--   Spatial block cross-validation during training
+-   *Hyperparameter tuning across `n.trees` , `interaction.depth` , `shrinkage`, `n.minobsinnode`*
+-   *Spatial block cross-validation during training*
 
-Generalized Linear Model (**SSDM_GLM**)
+*Generalized Linear Model (**SSDM_GLM**)*
 
--   Logistic model with binomial link function
--   Spatial block cross-validation during training
+-   *Logistic model with binomial link function*
+-   *Spatial block cross-validation during training*
 
-Neural Netwok (**SSDM_NN**)
+*Neural Netwok (**SSDM_NN**)*
 
--   Three hidden layers, leaky ReLu activations, binomial loss
--   no spatial block cross-validation during training
+-   *Three hidden layers, leaky ReLu activations, binomial loss*
+-   *no spatial block cross-validation during training*
 
-#### mSDM Algorithms
+#### *mSDM Algorithms*
 
-Binary Neural Network with species embedding (**MSDM_embed**)
+*Binary Neural Network with species embedding (**MSDM_embed**)*
 
--   definition: presence \~ environment + embedding(species)
--   prediction: probability of occurrence given a set of (environmental) inputs and species identity
--   embedding initialized at random
--   three hidden layers, sigmoid + leaky ReLu activations, binomial loss
+-   *definition: presence \~ environment + embedding(species)*
+-   *prediction: probability of occurrence given a set of (environmental) inputs and species identity*
+-   *embedding initialized at random*
+-   *three hidden layers, sigmoid + leaky ReLu activations, binomial loss*
 
-Binary Neural Network with trait-informed species embedding (**MSDM_embed_informed_trained**)
+*Binary Neural Network with trait-informed species embedding (**MSDM_embed_informed_trained**)*
 
--   definition: presence \~ environment + embedding(species)
--   prediction: probability of occurrence given a set of (environmental) inputs and species identity
--   embedding initialized using eigenvectors of functional distance matrix, then further training on data
--   three hidden layers, sigmoid + leaky ReLu activations, binomial loss
+-   *definition: presence \~ environment + embedding(species)*
+-   *prediction: probability of occurrence given a set of (environmental) inputs and species identity*
+-   *embedding initialized using eigenvectors of functional distance matrix, then further training on data*
+-   *three hidden layers, sigmoid + leaky ReLu activations, binomial loss*
 
-Multi-Class Neural Network (**MSDM_multiclass**)
+*Multi-Class Neural Network (**MSDM_multiclass**)*
 
--   definition: species identity \~ environment
--   prediction: probability distribution across all observed species given a set of (environmental) inputs
--   presence-only data in training
--   three hidden layers, leaky ReLu activations, softmax loss
--   Top-k based evaluation (k=10, P/A \~ target species in / not among top 10 predictions)
+-   *definition: species identity \~ environment*
+-   *prediction: probability distribution across all observed species given a set of (environmental) inputs*
+-   *presence-only data in training*
+-   *three hidden layers, leaky ReLu activations, softmax loss*
+-   *Top-k based evaluation (k=10, P/A \~ target species in / not among top 10 predictions)*
 
-### Key findings:
+### *Key findings:*
 
--   sSDM algorithms (RF, GBM) outperformed mSDMs in most cases
--   mSDMs showed indications of better performance for rare species (\< 10-20 occurrences)
--   More occurrence records and larger range sizes tended to improve model performance
--   Higher range coverage correlated with better performance
--   Range coverage bias and functional group showed some impact but were less consistent
--   Convergence problems hampered NN sSDM performance
+-   *sSDM algorithms (RF, GBM) outperformed mSDMs in most cases*
+-   *mSDMs showed indications of better performance for rare species (\< 10-20 occurrences)*
+-   *More occurrence records and larger range sizes tended to improve model performance*
+-   *Higher range coverage correlated with better performance*
+-   *Range coverage bias and functional group showed some impact but were less consistent*
+-   *Convergence problems hampered NN sSDM performance*
 
-## Analysis
+## *Analysis*
 
-The table below shows the analysed modeling results.
+*The table below shows the analysed modeling results.*
 
 ```{r performance, echo = FALSE, message=FALSE, warnings=FALSE}
-DT::datatable(performance)
+DT::datatable(performance) %>% 
+  formatRound(columns="value", digits=3)
 ```
 
-### Number of records
+### *Number of records*
 
--   Model performance was generally better for species with more observations
--   Very poor performance below 50-100 observations
+-   *Model performance was generally better for species with more observations*
+-   *Very poor performance below 50-100 observations*
 
 ```{r number_of_records, echo = FALSE, message=FALSE, warnings=FALSE}
-df_plot = performance
-
-# Calculate regression lines for each model and metric combination
-suppressWarnings({
-  regression_lines = df_plot %>%
-    group_by(model, metric) %>%
-    group_modify(~asym_regression(.x$obs, .x$value))
-})
-
-# Create base plot
-plot <- plot_ly() %>% 
-  layout(
-    title = "Model Performance vs. Number of observations",
-    xaxis = list(title = "Number of observations", type = "log"),
-    yaxis = list(title = "Value"),
-    legend = list(x = 1.1, y = 0.5),  # Move legend to the right of the plot
-    margin = list(r = 150),  # Add right margin to accommodate legend
-    hovermode = 'closest',
-    updatemenus = list(
-      list(
-        type = "dropdown",
-        active = 0,
-        buttons = plotly_buttons
-      )
+plot_performance_over_frequency = function(df_plot, metric) {
+  df_plot = dplyr::filter(df_plot, metric == !!metric)
+  
+  # Calculate regression lines for each model and metric combination
+  suppressWarnings({
+    regression_lines = df_plot %>%
+      group_by(model) %>%
+      group_modify( ~ asym_regression(.x$obs, .x$value))
+  })
+  
+  # Create base plot
+  plot <- plot_ly() %>%
+    layout(
+      title = "Model Performance vs. Number of observations",
+      xaxis = list(title = "Number of observations", type = "log"),
+      yaxis = list(title = metric),
+      legend = list(x = 1.1, y = 0.5), # Move legend to the right of the plot
+      margin = list(r = 150), # Add right margin to accommodate legend
+      hovermode = 'closest'
     )
-  )
-
-# Points
-for (model_name in unique(df_plot$model)) {
-  plot = plot %>%
-    add_markers(
-      data = filter(df_plot, model == model_name),
-      x = ~obs,
-      y = ~value,
-      color = model_name,  # Set color to match legendgroup
-      legendgroup = model_name,
-      opacity = 0.6,
-      name = ~model,
-      hoverinfo = 'text',
-      text = ~paste("Species:", species, "<br>Observations:", obs, "<br>Value:", round(value, 3)),
-      transforms = list(
-        list(
-          type = 'filter',
-          target = ~metric,
-          operation = '=',
-          value = focal_metrics[1]
+  
+  # Points
+  for (model_name in unique(df_plot$model)) {
+    plot = plot %>%
+      add_markers(
+        data = filter(df_plot, model == model_name, metric %in% focal_metrics),
+        x = ~ obs,
+        y = ~ value,
+        color = model_name, # Set color to match legendgroup
+        legendgroup = model_name,
+        opacity = 0.6,
+        name = ~ model,
+        hoverinfo = 'text',
+        text = ~ paste(
+          "Species:", species, "<br>Observations:", obs, "<br>Value:", round(value, 3)
         )
       )
-    )
-}
-
-# Add regression lines
-for(model_name in unique(df_plot$model)){
-  reg_data = dplyr::filter(regression_lines, model == model_name)
-  plot = plot %>% 
-    add_lines(
-      data = reg_data,
-      x = ~x,
-      y = ~fit,
-      color = model_name,  # Set color to match legendgroup
-      legendgroup = model_name,
-      name = paste(model_name, '(fit)'),
-      showlegend = FALSE,
-      transforms = list(
-        list(
-          type = 'filter',
-          target = ~metric,
-          operation = '=',
-          value = focal_metrics[1]
-        )
+  }
+  
+  # Add regression lines
+  for (model_name in unique(df_plot$model)) {
+    reg_data = dplyr::filter(regression_lines, model == model_name)
+    plot = plot %>%
+      add_lines(
+        data = reg_data,
+        x = ~ x,
+        y = ~ fit,
+        color = model_name, # Set color to match legendgroup
+        legendgroup = model_name,
+        name = paste(model_name, '(fit)'),
+        showlegend = FALSE
       )
-    )
+  }
+  
+  return(plot)
 }
 
+df_plot = performance %>% dplyr::left_join(obs_count, by = "species")
+```
+
+::: panel-tabset
+#### *AUC*
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "auc")
+bslib::card(plot, full_screen = T)
+```
+
+#### *F1*
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "f1")
+bslib::card(plot, full_screen = T)
+```
+
+#### *Cohen's kappa*
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "kappa")
+bslib::card(plot, full_screen = T)
+```
+
+#### *Accurracy*
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "accuracy")
 bslib::card(plot, full_screen = T)
 ```
 
-### Range characteristics
+#### *Precision*
 
-#### Range size
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "precision")
+bslib::card(plot, full_screen = T)
+```
+
+#### *Recall*
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "recall")
+bslib::card(plot, full_screen = T)
+```
+:::
 
-Range size was calculated based on polygon layers from the IUCN Red List of Threatened Species (2016).
 
--   Model performance tended to be slightly higher for species with larger range size
--   Only RF shows continuous performance improvements beyond range sizes of \~5M km²
 
-```{r range_size, echo = FALSE, message=FALSE, warnings=FALSE}
+### *Range characteristics*
+
+#### *Range size*
+
+*Range size was calculated based on polygon layers from the IUCN Red List of Threatened Species (2016).*
+
+-   *Model performance tended to be slightly higher for species with larger range size*
+-   *Only RF shows continuous performance improvements beyond range sizes of \~5M km²*
+
+```{r range_size, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 
 df_join = range_maps %>% 
   dplyr::mutate(range_size = as.numeric(st_area(range_maps) / 1000000)) %>%  # range in sqkm
@@ -317,17 +349,17 @@ for (model_name in unique(df_plot$model)) {
 bslib::card(plot, full_screen = T)
 ```
 
-#### Range coverage
+#### *Range coverage*
 
-Species ranges were split into continuous hexagonal grid cells of 1 degree diameter. Range coverage was then calculated as the number of grid cells containing at least one occurrence record divided by the number of total grid cells.
+*Species ranges were split into continuous hexagonal grid cells of 1 degree diameter. Range coverage was then calculated as the number of grid cells containing at least one occurrence record divided by the number of total grid cells.*
 
 $$
 RangeCoverage = \frac{N_{cells\_occ}}{N_{cells\_total}}
 $$
 
--   Models for species with higher range coverage showed slightly better performance
+-   *Models for species with higher range coverage showed slightly better performance*
 
-```{r range_coverage, echo = FALSE, message=FALSE, warnings=FALSE}
+```{r range_coverage, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 df_cells_total = range_maps_gridded %>%
   dplyr::rename("species" = name_matched) %>% 
   group_by(species) %>%
@@ -423,19 +455,19 @@ for (model_name in unique(df_plot$model)) {
 bslib::card(plot, full_screen = T)
 ```
 
-#### Range coverage bias
+#### *Range coverage bias*
 
-Range coverage bias was calculated as 1 minus the ratio of the actual range coverage and the hypothetical range coverage if all observations were maximally spread out across the range.
+*Range coverage bias was calculated as 1 minus the ratio of the actual range coverage and the hypothetical range coverage if all observations were maximally spread out across the range.*
 
 $$
 RangeCoverageBias = 1 - \frac{RangeCoverage}{min({N_{obs\_total}} / {N_{cells\_total}}, 1)}
 $$
 
-Higher bias values indicate that occurrence records are spatially more clustered within the range of the species.
+*Higher bias values indicate that occurrence records are spatially more clustered within the range of the species.*
 
--   There was no strong relationship between range coverage bias and model performance
+-   *There was no strong relationship between range coverage bias and model performance*
 
-```{r range_coverage_bias, echo = FALSE, message=FALSE, warnings=FALSE}
+```{r range_coverage_bias, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 df_occs_total = occs_final %>% 
   st_drop_geometry() %>% 
   group_by(species) %>% 
@@ -524,20 +556,20 @@ for (model_name in unique(df_plot$model)) {
 bslib::card(plot, full_screen = T)
 ```
 
-### Functional group
+### *Functional group*
 
-Functional groups were assigned based on taxonomic order. The following groupings were used:
+*Functional groups were assigned based on taxonomic order. The following groupings were used:*
 
-| Functional group      | Taxomic orders                                                        |
-|------------------|-----------------------------------------------------|
-| large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla                    |
-| small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha |
-| arboreal              | Primates, Pilosa                                                      |
-| flying                | Chiroptera                                                            |
+| *Functional group*      | *Taxomic orders*                                                        |
+|-------------------|-----------------------------------------------------|
+| *large ground-dwelling* | *Carnivora, Artiodactyla, Cingulata, Perissodactyla*                    |
+| *small ground-dwelling* | *Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha* |
+| *arboreal*              | *Primates, Pilosa*                                                      |
+| *flying*                | *Chiroptera*                                                            |
 
--   Models for bats tended to perform slightly worse than for other groups.
+-   *Models for bats tended to perform slightly worse than for other groups.*
 
-```{r functional_groups, echo = FALSE, message=FALSE, warnings=FALSE}
+```{r functional_groups, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 df_plot = performance %>% 
   dplyr::left_join(functional_groups, by = c("species" = "name_matched"))
 
@@ -582,4 +614,3 @@ plot <- plot %>%
 
 bslib::card(plot, full_screen = T)
 ```
-
-- 
GitLab


From 2cfd549197dc09a54d7c0e68ac39c11df4f0213c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Wed, 15 Jan 2025 14:57:57 +0100
Subject: [PATCH 4/4] some changes to modeling pipelines and quick additions
 for updated report

---
 R/01_01_range_map_preparation.R          |   4 +-
 R/03_02_absence_preparation.R            |  17 +-
 R/04_01_ssdm_modeling.R                  |  67 ++++----
 R/04_03_msdm_embed_raw.R                 |  40 +++--
 R/04_07_msdm_embed_multi_nolonlat.R      |  20 ++-
 R/04_07_msdm_oneoff.R                    | 111 +++++++++++++
 R/05_01_performance_analysis.qmd         | 189 +++++++++++-----------
 R/05_01_performance_analysis_carsten.qmd | 190 +++++++++++++++++++++++
 R/_publish.yml                           |   4 +
 R/utils.R                                |   2 +-
 renv.lock                                | 151 ++++++++++++++++++
 11 files changed, 637 insertions(+), 158 deletions(-)
 create mode 100644 R/04_07_msdm_oneoff.R
 create mode 100644 R/05_01_performance_analysis_carsten.qmd

diff --git a/R/01_01_range_map_preparation.R b/R/01_01_range_map_preparation.R
index 0a07176..eb33fdc 100644
--- a/R/01_01_range_map_preparation.R
+++ b/R/01_01_range_map_preparation.R
@@ -79,8 +79,8 @@ geometries_unique = range_maps_gridded_id %>%
   group_by(geom_id) %>% 
   slice_head(n = 1)
 
-geom_dist = sf::st_distance(geometries_unique, geometries_unique)  # Takes ~ 10 mins
-  %>% as.matrix()
+geom_dist = sf::st_distance(geometries_unique, geometries_unique)  %>%  # Takes ~ 10 mins
+  as.matrix()
 
 range_maps_split = range_maps_gridded_id %>% 
   group_by(name_matched) %>% 
diff --git a/R/03_02_absence_preparation.R b/R/03_02_absence_preparation.R
index 914ddb7..b148c64 100644
--- a/R/03_02_absence_preparation.R
+++ b/R/03_02_absence_preparation.R
@@ -49,9 +49,9 @@ range_maps = st_transform(range_maps, proj_string)
 # geographically close (reproduce sampling biases) but environmentally       #
 # dissimilar (avoid false negatives) to the known occurrences                #
 # ---------------------------------------------------------------------------#
-future::plan("multisession", workers = 16)
+future::plan("multisession", workers = 24)
 
-model_data = furrr::future_walk(
+furrr::future_walk(
   .x = target_species,
   .options = furrr::furrr_options(seed = 42, scheduling = FALSE), # make sure workers get constant stream of work
   .env_globals = c(raster_filepaths, sa_polygon, occs_final, range_maps, proj_string),
@@ -126,13 +126,13 @@ model_data = furrr::future_walk(
                       y = y + runif(nrow(.), -5000, 5000)) %>%     # Add jitter (res/2) to cell centroids 
         st_as_sf(coords = c("x", "y"), crs = proj_string)
       
-      sample_abs = bind_cols(
-        terra::extract(raster_data, terra::vect(sample_points), ID = FALSE)
-      ) %>% 
-        bind_cols(sample_points) %>% 
+      sample_abs = sample_points %>% 
+        bind_cols(
+          terra::extract(raster_data, terra::vect(sample_points), ID = FALSE)
+        ) %>% 
         drop_na()
       
-      abs_spec_list[[length(abs_spec_list)+1]] = sample_points
+      abs_spec_list[[length(abs_spec_list)+1]] = sample_abs
       
       samples_required = samples_required - nrow(sample_points) # Sometimes there are no env data for sample points, so keep sampling
     }
@@ -154,7 +154,7 @@ model_data = furrr::future_walk(
     # Create presence-absence dataframe
     pa_spec = occs_spec %>% 
       dplyr::mutate(present = 1) %>% 
-      bind_rows(abs_spec) 
+      bind_rows(abs_spec)
     
     ggplot() +
       ggtitle(species) +
@@ -189,7 +189,6 @@ model_data = furrr::future_walk(
     })
     
     pa_spec$fold_eval = folds
-    pa_spec$geometry = NULL
     
     save(pa_spec, file = paste0("data/r_objects/pa_sampling/", species, ".RData"))
   }
diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
index 5a0de5e..914c1aa 100644
--- a/R/04_01_ssdm_modeling.R
+++ b/R/04_01_ssdm_modeling.R
@@ -1,27 +1,30 @@
 library(furrr)
 library(progressr)
 library(dplyr)
+library(cito)
 
 source("R/utils.R")
 
 # ----------------------------------------------------------------------#
 # Define Training function                                           ####
 # ----------------------------------------------------------------------#
-train_models = function(data_split){
-  future::plan("multisession", workers = 16)
-  p = progressor(along = data_split)
+train_models = function(pa_files){
+  future::plan("multisession", workers = 4)
+  p = progressor(along = pa_files)
   
-  ssdm_results_list = furrr::future_map(
-    .x = data_split, 
+  furrr::future_walk(
+    .x = pa_files, 
     .options = furrr::furrr_options(
       seed = 123,
-      packages = c("dplyr", "tidyr", "sf", "caret", "cito", "pROC")
+      packages = c("dplyr", "tidyr", "sf", "caret", "cito", "pROC"),
+      scheduling = FALSE  # make sure workers get constant stream of work
     ),
-    .f = function(data_spec){
-      spec = data_spec$species[1]
-      p(sprintf("spec=%s", spec))
+    .f = function(pa_file){
+      load(pa_file)
+      species = pa_spec$species[1]
+      p(sprintf("species=%s", species))
       
-      if(all(is.na(data_spec$fold_eval))){
+      if(all(is.na(pa_spec$fold_eval))){
         warning("Too few samples")
         return()
       }
@@ -37,12 +40,12 @@ train_models = function(data_split){
       )
       
       # Create factor presence column 
-      data_spec$present_fct = factor(data_spec$present, levels = c("0", "1"), labels = c("A", "P"))
+      pa_spec$present_fct = factor(pa_spec$present, levels = c("0", "1"), labels = c("A", "P"))
       
       # Outer CV loop (for averaging performance metrics)
-      performance_cv = lapply(sort(unique(data_spec$fold_eval)), function(k){
-        data_test = dplyr::filter(data_spec, fold_eval == k)
-        data_train = dplyr::filter(data_spec, fold_eval != k)
+      performance_cv = lapply(sort(unique(pa_spec$fold_eval)), function(k){
+        data_test = dplyr::filter(pa_spec, fold_eval == k)
+        data_train = dplyr::filter(pa_spec, fold_eval != k)
         
         if(nrow(data_test) == 0 || nrow(data_train) == 0){
           warning("Too few samples")
@@ -75,7 +78,7 @@ train_models = function(data_split){
         )
         
         # Define predictors
-        predictors = paste0("layer_", 1:19)
+        predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
         
         # Random Forest #####
         rf_performance = tryCatch({
@@ -84,7 +87,6 @@ train_models = function(data_split){
             x = data_train[, predictors],
             y = data_train$present_fct,
             method = "rf",
-            metric = "AUC",
             trControl = train_ctrl,
             tuneLength = 4,
             verbose = F
@@ -101,7 +103,6 @@ train_models = function(data_split){
             x = data_train[, predictors],
             y = data_train$present_fct,
             method = "gbm",
-            metric = "AUC",
             trControl = train_ctrl,
             tuneLength = 4,
             verbose = F
@@ -117,7 +118,6 @@ train_models = function(data_split){
             x = data_train[, predictors],
             y = data_train$present_fct,
             method = "gamSpline",
-            metric = "AUC",
             tuneLength = 4,
             trControl = train_ctrl
           )
@@ -136,7 +136,7 @@ train_models = function(data_split){
             hidden = c(100L, 100L, 100L),
             loss = "binomial",
             activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
-            epochs = 400L, 
+            epochs = 500L, 
             burnin = 100L,
             lr = 0.001,   
             batchsize = max(nrow(data_train)/10, 64),
@@ -150,14 +150,18 @@ train_models = function(data_split){
             plot = F
           )
           
-          evaluate_model(nn_fit, data_test)
+          if(nn_fit$successfull == 1){
+            evaluate_model(nn_fit, data_test)  
+          } else {
+            na_performance
+          }
         }, error = function(e){
           na_performance
         })
         
         # Summarize results
         performance_summary = tibble(
-          species = spec,
+          species = !!species,
           obs = nrow(data_train),
           fold_eval = k,
           model = c("RF", "GBM", "GAM", "NN"),
@@ -175,8 +179,7 @@ train_models = function(data_split){
       
       # Combine and save evaluation results
       performance_spec = bind_rows(performance_cv)
-      filename = paste0(gsub(" ", "_", spec), "_results.RData")
-      save(performance_spec, file = file.path("data/r_objects/nested_cv/model_results/", filename))
+      save(performance_spec, file = paste0("data/r_objects/model_results/", species, ".RData"))
     }
   )
 }
@@ -184,22 +187,22 @@ train_models = function(data_split){
 # ----------------------------------------------------------------------#
 # Run training                                                       ####
 # ----------------------------------------------------------------------#
-load("data/r_objects/nested_cv/model_data.RData")
-
 handlers(global = TRUE)
 handlers("progress")
 
-data_split = group_split(model_data, species)
+pa_files = list.files("data/r_objects/pa_sampling/", full.names = T)
+species_processed = stringr::str_remove_all(list.files("data/r_objects/model_results/"), ".RData")
+pa_files_run = pa_files[!sapply(pa_files, function(x){any(stringr::str_detect(x, species_processed))})]
 
-train_models(data_split)
+train_models(pa_files_run)
 
 # ----------------------------------------------------------------------#
 # Combine results                                                    ####
 # ----------------------------------------------------------------------#
-files = list.files("data/r_objects/nested_cv/model_results/", full.names = T)
+files = list.files("data/r_objects/model_results/", full.names = T)
 ssdm_results = lapply(files, function(f){load(f); return(performance_spec)}) %>% 
-  bind_rows() %>% 
-  dplyr::group_by(species, model, metric) %>% 
-  dplyr::summarize(value = mean(value, na.rm =T))
+  bind_rows() 
+  #dplyr::group_by(species, model, metric) %>% 
+  #dplyr::summarize(value = mean(value, na.rm =T))
 
-save(ssdm_results, file = "data/r_objects/nested_cv/ssdm_results.RData")
+save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
diff --git a/R/04_03_msdm_embed_raw.R b/R/04_03_msdm_embed_raw.R
index 03a8737..6b2df02 100644
--- a/R/04_03_msdm_embed_raw.R
+++ b/R/04_03_msdm_embed_raw.R
@@ -4,36 +4,40 @@ library(cito)
 
 source("R/utils.R")
 
-load("data/r_objects/model_data.RData")
+load("data/r_objects/model_data_pa_sampling.RData")
 
-model_data$species_int = as.integer(as.factor(model_data$species))
-
-train_data = dplyr::filter(model_data, train == 1)
-test_data = dplyr::filter(model_data, train == 0)
+model_data = model_data %>% 
+  dplyr::mutate(species_int = as.integer(as.factor(model_data$species))) %>% 
+  sf::st_drop_geometry()
+  
+test_data = dplyr::filter(model_data, fold_eval == 1) %>% 
+  dplyr::select(-fold_eval)
+train_data = dplyr::filter(model_data, fold_eval != 1) %>% 
+  dplyr::select(-fold_eval)
 
 # ----------------------------------------------------------------------#
 # Train model                                                        ####
 # ----------------------------------------------------------------------#
-predictors = paste0("layer_", 1:19)
-formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, dim = 50, lambda = 0.000001)"))
+predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
+formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, dim = 50, train = T, lambda = 0.0001)"))
 
-plot(1, type="n", xlab="", ylab="", xlim=c(0, 15000), ylim=c(0, 0.7)) # empty plot with better limits, draw points in there
+plot(1, type="n", xlab="", ylab="", xlim=c(0, 25000), ylim=c(0, 0.7)) # empty plot with better limits, draw points in there
 msdm_fit_embedding_raw = dnn(
   formula,
   data = train_data,
   hidden = c(500L, 500L, 500L),
   loss = "binomial",
   activation = c("sigmoid", "leaky_relu", "leaky_relu"),
-  epochs = 15000L, 
+  epochs = 10L, 
   lr = 0.01,   
   baseloss = 1,
   batchsize = nrow(train_data),
   dropout = 0.1,
-  burnin = 100,
+  burnin = 500,
   optimizer = config_optimizer("adam", weight_decay = 0.001),
-  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
-  early_stopping = 250,
-  validation = 0.3,
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 150, factor = 0.7),
+  early_stopping = 400,
+  validation = 0.2,
   device = "cuda"
 )
 
@@ -44,20 +48,22 @@ save(msdm_fit_embedding_raw, file = "data/r_objects/msdm_fit_embedding_raw.RData
 # ----------------------------------------------------------------------#
 load("data/r_objects/msdm_fit_embedding_raw.RData")
 
-data_split = split(model_data, model_data$species)
+data_split = test_data %>% 
+  split(test_data$species)
 
 msdm_results_embedding_raw = lapply(data_split, function(data_spec){
-  test_data = dplyr::filter(data_spec, train == 0) %>% 
+  species = data_spec$species[1]
+  data_spec = data_spec %>% 
     dplyr::select(-species)
   
   msdm_performance = tryCatch({
-    evaluate_model(msdm_fit_embedding_raw, test_data)
+    evaluate_model(msdm_fit_embedding_raw, data_spec)
   }, error = function(e){
     list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
   })
   
   performance_summary = tibble(
-    species = data_spec$species[1],
+    species = !!species,
     obs = nrow(data_spec),
     model = "MSDM_embed",
     auc = msdm_performance$AUC,
diff --git a/R/04_07_msdm_embed_multi_nolonlat.R b/R/04_07_msdm_embed_multi_nolonlat.R
index 2044e03..15d37e9 100644
--- a/R/04_07_msdm_embed_multi_nolonlat.R
+++ b/R/04_07_msdm_embed_multi_nolonlat.R
@@ -1,10 +1,11 @@
 library(dplyr)
 library(tidyr)
 library(cito)
+library(sf)
 
 source("R/utils.R")
 
-load("data/r_objects/model_data.RData")
+load("data/r_objects/model_data_pa_sampling.RData")
 load("data/r_objects/func_dist.RData")
 load("data/r_objects/phylo_dist.RData")
 load("data/r_objects/range_dist.RData")
@@ -18,11 +19,13 @@ model_species = Reduce(
 ) 
 
 model_data_final = model_data %>%
+  sf::st_drop_geometry() %>% 
   dplyr::filter(species %in% !!model_species) %>% 
   dplyr::mutate(species_int = as.integer(as.factor(species)))
+  
 
-train_data = dplyr::filter(model_data_final, train == 1)
-test_data = dplyr::filter(model_data_final, train == 0)
+test_data = dplyr::filter(model_data_final, fold_eval == 1)
+train_data = dplyr::filter(model_data_final, fold_eval != 1)
 
 # Create embeddings
 func_ind = match(model_species, colnames(func_dist))
@@ -41,14 +44,15 @@ range_embeddings = eigen(range_dist)$vectors[,1:20]
 # ----------------------------------------------------------------------#
 # Train model                                                        ####
 # ----------------------------------------------------------------------#
-predictors = paste0("layer_", 1:19)
+# Define predictors
+predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
 
 formula = as.formula(
   paste0("present ~ ", 
          paste(predictors, collapse = '+'),  
-         " + e(species_int, weights = func_embeddings, lambda = 0.00001, train = F)",
-         " + e(species_int, weights = phylo_embeddings, lambda = 0.00001, train = F)",
-         " + e(species_int, weights = range_embeddings, lambda = 0.00001, train = F)"
+         " + e(species_int, weights = func_embeddings, lambda = 0.00001, train = T)",
+         " + e(species_int, weights = phylo_embeddings, lambda = 0.00001, train = T)",
+         " + e(species_int, weights = range_embeddings, lambda = 0.00001, train = T)"
   )
 )
 
@@ -56,7 +60,7 @@ plot(1, type="n", xlab="", ylab="", xlim=c(0, 25000), ylim=c(0, 0.7)) # empty pl
 msdm_fit_embedding_multi_nolonlat = dnn(
   formula,
   data = train_data,
-  hidden = c(500L, 500L, 500L),
+  hidden = c(200L, 200L, 200L),
   loss = "binomial",
   activation = c("sigmoid", "leaky_relu", "leaky_relu"),
   epochs = 30000L, 
diff --git a/R/04_07_msdm_oneoff.R b/R/04_07_msdm_oneoff.R
new file mode 100644
index 0000000..4ef2a03
--- /dev/null
+++ b/R/04_07_msdm_oneoff.R
@@ -0,0 +1,111 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+library(sf)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data_pa_sampling.RData")
+load("data/r_objects/func_dist.RData")
+load("data/r_objects/phylo_dist.RData")
+load("data/r_objects/range_dist.RData")
+
+# ----------------------------------------------------------------------#
+# Prepare data                                                       ####
+# ----------------------------------------------------------------------#
+model_species = Reduce(
+  intersect, 
+  list(unique(model_data$species), colnames(range_dist), colnames(phylo_dist), colnames(func_dist))
+) 
+
+model_data_final = model_data %>%
+  sf::st_drop_geometry() %>% 
+  dplyr::filter(species %in% !!model_species) %>% 
+  dplyr::mutate(species_int = as.integer(as.factor(species)))
+  
+
+test_data = dplyr::filter(model_data_final, fold_eval == 1)
+train_data = dplyr::filter(model_data_final, fold_eval != 1)
+
+# Create embeddings
+func_ind = match(model_species, colnames(func_dist))
+func_dist = func_dist[func_ind, func_ind]
+func_embeddings = eigen(func_dist)$vectors[,1:20]
+
+phylo_ind = match(model_species, colnames(phylo_dist))
+phylo_dist = phylo_dist[phylo_ind, phylo_ind]
+phylo_embeddings = eigen(phylo_dist)$vectors[,1:20]
+
+range_ind = match(model_species, colnames(range_dist))
+range_dist = range_dist[range_ind, range_ind]
+range_embeddings = eigen(range_dist)$vectors[,1:20]
+
+
+# ----------------------------------------------------------------------#
+# Train model                                                        ####
+# ----------------------------------------------------------------------#
+# Define predictors
+predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
+
+formula = as.formula(
+  paste0("present ~ ", 
+         paste(predictors, collapse = '+'),  
+         " + e(species_int, weights = func_embeddings, lambda = 0.00001, train = T)",
+         " + e(species_int, weights = phylo_embeddings, lambda = 0.00001, train = T)",
+         " + e(species_int, weights = range_embeddings, lambda = 0.00001, train = T)"
+  )
+)
+
+plot(1, type="n", xlab="", ylab="", xlim=c(0, 25000), ylim=c(0, 0.7)) # empty plot with better limits, draw points in there
+msdm_fit_embedding_multi_nolonlat = dnn(
+  formula,
+  data = train_data,
+  hidden = c(400L, 400L, 400L),
+  loss = "binomial",
+  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+  epochs = 30000L, 
+  lr = 0.01,   
+  baseloss = 1,
+  batchsize = nrow(train_data),
+  dropout = 0.1,
+  burnin = 100,
+  optimizer = config_optimizer("adam", weight_decay = 0.001),
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 150, factor = 0.7),
+  early_stopping = 250,
+  validation = 0.3,
+  device = "cuda",
+)
+save(msdm_fit_embedding_multi_nolonlat, file = "data/r_objects/msdm_fit_embedding_multi_nolonlat.RData")
+
+# ----------------------------------------------------------------------#
+# Evaluate results                                                   ####
+# ----------------------------------------------------------------------#
+load("data/r_objects/msdm_fit_embedding_multi_nolonlat.RData")
+data_split = test_data %>% 
+  group_by(species_int) %>% 
+  group_split()
+
+msdm_results_embedding_multi_nolonlat = lapply(data_split, function(data_spec){
+  target_species =  data_spec$species[1]
+  data_spec = dplyr::select(data_spec, -species)
+  
+  msdm_performance = tryCatch({
+    evaluate_model(msdm_fit_embedding_multi_nolonlat, data_spec)
+  }, error = function(e){
+    list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
+  })
+  
+  performance_summary = tibble(
+    species = !!target_species,
+    obs = length(which(model_data$species == target_species)),
+    model = "MSDM_embed_informed_multi_nolonlat",
+    auc = msdm_performance$AUC,
+    accuracy = msdm_performance$Accuracy,
+    kappa = msdm_performance$Kappa,
+    precision = msdm_performance$Precision,
+    recall = msdm_performance$Recall,
+    f1 = msdm_performance$F1
+  )
+}) %>% bind_rows()
+
+save(msdm_results_embedding_multi_nolonlat, file = "data/r_objects/msdm_results_embedding_multi_nolonlat.RData")
diff --git a/R/05_01_performance_analysis.qmd b/R/05_01_performance_analysis.qmd
index 44b9b8e..17be731 100644
--- a/R/05_01_performance_analysis.qmd
+++ b/R/05_01_performance_analysis.qmd
@@ -11,12 +11,10 @@ library(sf)
 library(plotly)
 library(DT)
 
-load("../data/r_objects/nested_cv/model_data.RData")
-load("../data/r_objects/nested_cv/ssdm_results.RData")
-#load("../data/r_objects/msdm_fit.RData")
-#load("../data/r_objects/msdm_results.RData")
-#load("../data/r_objects/msdm_results_embedding_trained.RData")
-#load("../data/r_objects/msdm_results_multiclass.RData.")
+load("../data/r_objects/model_data_pa_sampling.RData")
+load("../data/r_objects/ssdm_results.RData")
+load("../data/r_objects/msdm_results_embedding_raw.RData")
+
 load("../data/r_objects/range_maps.RData")
 load("../data/r_objects/range_maps_gridded.RData")
 load("../data/r_objects/occs_final.RData")
@@ -25,6 +23,8 @@ sf::sf_use_s2(use_s2 = FALSE)
 ```
 
 ```{r globals, echo = FALSE, include = FALSE}
+
+
 # Count occs per species
 obs_count = model_data %>% 
   sf::st_drop_geometry() %>% 
@@ -52,8 +52,19 @@ lin_regression = function(x, y, family = "binomial"){
   )
 }
 
+msdm_results = msdm_results_embedding_raw %>% 
+  pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value") %>% 
+  dplyr::select(-obs) %>% 
+  dplyr::mutate(
+    fold_eval = 1
+  ) %>% 
+  drop_na()
+
 # Performance table
-performance = bind_rows(ssdm_results) %>%
+performance = ssdm_results %>% 
+  dplyr::select(-obs) %>% 
+  dplyr::filter(fold_eval == 1, species %in% msdm_results$species) %>%  # Only look at first fold
+  bind_rows(msdm_results) %>% 
   ungroup() %>% 
   dplyr::mutate(
     value = case_when(
@@ -66,93 +77,94 @@ performance = bind_rows(ssdm_results) %>%
 focal_metrics = unique(performance$metric)
 ```
 
-## *Summary*
+## Summary
 
-*This document summarizes the performance of different sSDM and mSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed along five potential influence factors (number of records, range size, range coverage, range coverage bias, and functional group). The comparison of sSDM vs mSDM approaches is of particular interest.*
+This document summarizes the performance of different sSDM and mSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed along five potential influence factors (number of records, range size, range coverage, range coverage bias, and functional group). The comparison of sSDM vs mSDM approaches is of particular interest.
 
-*Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).*
+Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).
 
-### *Modeling overview:*
+### Modeling overview:
 
-#### *General decisions*
+#### General decisions
 
--   *Randomly sampled pseudo-absences from expanded area of extent of occurrence records (×1.25)*
--   *Balanced presences and absences for each species*
--   *Predictors: all 19 CHELSA bioclim variables*
--   *70/30 Split of training vs. test data (except for NN models)*
+-   Randomly sampled pseudo-absences from expanded area of extent of occurrence records (×1.25)
+-   Balanced presences and absences for each species
+-   Predictors: all 19 CHELSA bioclim variables
+-   70/30 Split of training vs. test data (except for NN models)
 
-#### *sSDM Algorithms*
+#### sSDM Algorithms
 
-*Random Forest (**SSDM_RF**)*
+Random Forest (**SSDM_RF**)
 
--   *Hyperparameter tuning of `mtry`*
--   *Spatial block cross-validation during training*
+-   Hyperparameter tuning of `mtry`
+-   Spatial block cross-validation during training
 
-*Generalized boosted machine (**SSDM_GBM**)*
+Generalized boosted machine (**SSDM_GBM**)
 
--   *Hyperparameter tuning across `n.trees` , `interaction.depth` , `shrinkage`, `n.minobsinnode`*
--   *Spatial block cross-validation during training*
+-   Hyperparameter tuning across `n.trees` , `interaction.depth` , `shrinkage`, `n.minobsinnode`
+-   Spatial block cross-validation during training
 
-*Generalized Linear Model (**SSDM_GLM**)*
+Generalized Linear Model (**SSDM_GLM**)
 
--   *Logistic model with binomial link function*
--   *Spatial block cross-validation during training*
+-   Logistic model with binomial link function
+-   Spatial block cross-validation during training
 
-*Neural Netwok (**SSDM_NN**)*
+Neural Netwok (**SSDM_NN**)
 
--   *Three hidden layers, leaky ReLu activations, binomial loss*
--   *no spatial block cross-validation during training*
+-   Three hidden layers, leaky ReLu activations, binomial loss
+-   no spatial block cross-validation during training
 
-#### *mSDM Algorithms*
+#### mSDM Algorithms
 
-*Binary Neural Network with species embedding (**MSDM_embed**)*
+Binary Neural Network with species embedding (**MSDM_embed**)
 
--   *definition: presence \~ environment + embedding(species)*
--   *prediction: probability of occurrence given a set of (environmental) inputs and species identity*
--   *embedding initialized at random*
--   *three hidden layers, sigmoid + leaky ReLu activations, binomial loss*
+-   definition: presence \~ environment + embedding(species)
+-   prediction: probability of occurrence given a set of (environmental) inputs and species identity
+-   embedding initialized at random
+-   three hidden layers, sigmoid + leaky ReLu activations, binomial loss
 
-*Binary Neural Network with trait-informed species embedding (**MSDM_embed_informed_trained**)*
+Binary Neural Network with trait-informed species embedding (**MSDM_embed_informed_trained**)
 
--   *definition: presence \~ environment + embedding(species)*
--   *prediction: probability of occurrence given a set of (environmental) inputs and species identity*
--   *embedding initialized using eigenvectors of functional distance matrix, then further training on data*
--   *three hidden layers, sigmoid + leaky ReLu activations, binomial loss*
+-   definition: presence \~ environment + embedding(species)
+-   prediction: probability of occurrence given a set of (environmental) inputs and species identity
+-   embedding initialized using eigenvectors of functional distance matrix, then further training on data
+-   three hidden layers, sigmoid + leaky ReLu activations, binomial loss
 
-*Multi-Class Neural Network (**MSDM_multiclass**)*
+Multi-Class Neural Network (**MSDM_multiclass**)
 
--   *definition: species identity \~ environment*
--   *prediction: probability distribution across all observed species given a set of (environmental) inputs*
--   *presence-only data in training*
--   *three hidden layers, leaky ReLu activations, softmax loss*
--   *Top-k based evaluation (k=10, P/A \~ target species in / not among top 10 predictions)*
+-   definition: species identity \~ environment
+-   prediction: probability distribution across all observed species given a set of (environmental) inputs
+-   presence-only data in training
+-   three hidden layers, leaky ReLu activations, softmax loss
+-   Top-k based evaluation (k=10, P/A \~ target species in / not among top 10 predictions)
 
-### *Key findings:*
+### Key findings:
 
--   *sSDM algorithms (RF, GBM) outperformed mSDMs in most cases*
--   *mSDMs showed indications of better performance for rare species (\< 10-20 occurrences)*
--   *More occurrence records and larger range sizes tended to improve model performance*
--   *Higher range coverage correlated with better performance*
--   *Range coverage bias and functional group showed some impact but were less consistent*
--   *Convergence problems hampered NN sSDM performance*
+-   sSDM algorithms (RF, GBM) outperformed mSDMs in most cases
+-   mSDMs showed indications of better performance for rare species (\< 10-20 occurrences)
+-   More occurrence records and larger range sizes tended to improve model performance
+-   Higher range coverage correlated with better performance
+-   Range coverage bias and functional group showed some impact but were less consistent
+-   Convergence problems hampered NN sSDM performance
 
-## *Analysis*
+## Analysis
 
-*The table below shows the analysed modeling results.*
+The table below shows the analysed modeling results.
 
 ```{r performance, echo = FALSE, message=FALSE, warnings=FALSE}
 DT::datatable(performance) %>% 
   formatRound(columns="value", digits=3)
 ```
 
-### *Number of records*
+### Number of records
 
--   *Model performance was generally better for species with more observations*
--   *Very poor performance below 50-100 observations*
+-   Model performance was generally better for species with more observations
+-   Very poor performance below 50-100 observations
 
 ```{r number_of_records, echo = FALSE, message=FALSE, warnings=FALSE}
 plot_performance_over_frequency = function(df_plot, metric) {
-  df_plot = dplyr::filter(df_plot, metric == !!metric)
+
+  df_plot = dplyr::filter(df_plot, metric == !!metric) 
   
   # Calculate regression lines for each model and metric combination
   suppressWarnings({
@@ -208,46 +220,47 @@ plot_performance_over_frequency = function(df_plot, metric) {
   return(plot)
 }
 
-df_plot = performance %>% dplyr::left_join(obs_count, by = "species")
+df_plot = performance %>% dplyr::left_join(obs_count, by = "species") 
+  
 ```
 
 ::: panel-tabset
-#### *AUC*
+#### AUC
 
 ```{r echo = FALSE}
 plot = plot_performance_over_frequency(df_plot, metric = "auc")
 bslib::card(plot, full_screen = T)
 ```
 
-#### *F1*
+#### F1
 
 ```{r echo = FALSE}
 plot = plot_performance_over_frequency(df_plot, metric = "f1")
 bslib::card(plot, full_screen = T)
 ```
 
-#### *Cohen's kappa*
+#### Cohen's kappa
 
 ```{r echo = FALSE}
 plot = plot_performance_over_frequency(df_plot, metric = "kappa")
 bslib::card(plot, full_screen = T)
 ```
 
-#### *Accurracy*
+#### Accurracy
 
 ```{r echo = FALSE}
 plot = plot_performance_over_frequency(df_plot, metric = "accuracy")
 bslib::card(plot, full_screen = T)
 ```
 
-#### *Precision*
+#### Precision
 
 ```{r echo = FALSE}
 plot = plot_performance_over_frequency(df_plot, metric = "precision")
 bslib::card(plot, full_screen = T)
 ```
 
-#### *Recall*
+#### Recall
 
 ```{r echo = FALSE}
 plot = plot_performance_over_frequency(df_plot, metric = "recall")
@@ -255,16 +268,14 @@ bslib::card(plot, full_screen = T)
 ```
 :::
 
+### Range characteristics
 
+#### Range size
 
-### *Range characteristics*
-
-#### *Range size*
-
-*Range size was calculated based on polygon layers from the IUCN Red List of Threatened Species (2016).*
+Range size was calculated based on polygon layers from the IUCN Red List of Threatened Species (2016).
 
--   *Model performance tended to be slightly higher for species with larger range size*
--   *Only RF shows continuous performance improvements beyond range sizes of \~5M km²*
+-   Model performance tended to be slightly higher for species with larger range size
+-   Only RF shows continuous performance improvements beyond range sizes of \~5M km²
 
 ```{r range_size, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 
@@ -349,15 +360,15 @@ for (model_name in unique(df_plot$model)) {
 bslib::card(plot, full_screen = T)
 ```
 
-#### *Range coverage*
+#### Range coverage
 
-*Species ranges were split into continuous hexagonal grid cells of 1 degree diameter. Range coverage was then calculated as the number of grid cells containing at least one occurrence record divided by the number of total grid cells.*
+Species ranges were split into continuous hexagonal grid cells of 1 degree diameter. Range coverage was then calculated as the number of grid cells containing at least one occurrence record divided by the number of total grid cells.
 
 $$
 RangeCoverage = \frac{N_{cells\_occ}}{N_{cells\_total}}
 $$
 
--   *Models for species with higher range coverage showed slightly better performance*
+-   Models for species with higher range coverage showed slightly better performance
 
 ```{r range_coverage, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 df_cells_total = range_maps_gridded %>%
@@ -455,17 +466,17 @@ for (model_name in unique(df_plot$model)) {
 bslib::card(plot, full_screen = T)
 ```
 
-#### *Range coverage bias*
+#### Range coverage bias
 
-*Range coverage bias was calculated as 1 minus the ratio of the actual range coverage and the hypothetical range coverage if all observations were maximally spread out across the range.*
+Range coverage bias was calculated as 1 minus the ratio of the actual range coverage and the hypothetical range coverage if all observations were maximally spread out across the range.
 
 $$
 RangeCoverageBias = 1 - \frac{RangeCoverage}{min({N_{obs\_total}} / {N_{cells\_total}}, 1)}
 $$
 
-*Higher bias values indicate that occurrence records are spatially more clustered within the range of the species.*
+Higher bias values indicate that occurrence records are spatially more clustered within the range of the species.
 
--   *There was no strong relationship between range coverage bias and model performance*
+-   There was no strong relationship between range coverage bias and model performance
 
 ```{r range_coverage_bias, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 df_occs_total = occs_final %>% 
@@ -556,18 +567,18 @@ for (model_name in unique(df_plot$model)) {
 bslib::card(plot, full_screen = T)
 ```
 
-### *Functional group*
+### Functional group
 
-*Functional groups were assigned based on taxonomic order. The following groupings were used:*
+Functional groups were assigned based on taxonomic order. The following groupings were used:
 
-| *Functional group*      | *Taxomic orders*                                                        |
-|-------------------|-----------------------------------------------------|
-| *large ground-dwelling* | *Carnivora, Artiodactyla, Cingulata, Perissodactyla*                    |
-| *small ground-dwelling* | *Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha* |
-| *arboreal*              | *Primates, Pilosa*                                                      |
-| *flying*                | *Chiroptera*                                                            |
+| Functional group      | Taxomic orders                                                        |
+|--------------------|----------------------------------------------------|
+| large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla                    |
+| small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha |
+| arboreal              | Primates, Pilosa                                                      |
+| flying                | Chiroptera                                                            |
 
--   *Models for bats tended to perform slightly worse than for other groups.*
+-   Models for bats tended to perform slightly worse than for other groups.
 
 ```{r functional_groups, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
 df_plot = performance %>% 
diff --git a/R/05_01_performance_analysis_carsten.qmd b/R/05_01_performance_analysis_carsten.qmd
new file mode 100644
index 0000000..7dee2c4
--- /dev/null
+++ b/R/05_01_performance_analysis_carsten.qmd
@@ -0,0 +1,190 @@
+---
+title: "SDM Performance analysis"
+format: html
+editor: visual
+engine: knitr
+---
+
+```{r init, echo = FALSE, include = FALSE}
+library(tidyverse)
+library(sf)
+library(plotly)
+library(DT)
+
+load("../data/r_objects/model_data_pa_sampling.RData")
+load("../data/r_objects/ssdm_results.RData")
+load("../data/r_objects/msdm_results_embedding_raw.RData")
+
+load("../data/r_objects/range_maps.RData")
+load("../data/r_objects/range_maps_gridded.RData")
+load("../data/r_objects/occs_final.RData")
+load("../data/r_objects/functional_groups.RData")
+sf::sf_use_s2(use_s2 = FALSE)
+```
+
+```{r globals, echo = FALSE, include = FALSE}
+
+
+# Count occs per species
+obs_count = model_data %>% 
+  sf::st_drop_geometry() %>% 
+  dplyr::filter(present == 1) %>% 
+  dplyr::group_by(species) %>% 
+  dplyr::summarise(obs = n())
+
+
+# Regression functions
+asym_regression = function(x, y){
+  nls_fit = nls(y ~ 1 - (1-b) * exp(-c * log(x)), start = list(b = 0.1, c = 0.1))
+  new_x = exp(seq(log(min(x)), log(max(x)), length.out = 100))
+  data.frame(
+    x = new_x,
+    fit = predict(nls_fit, newdata = data.frame(x = new_x))
+  )
+}
+
+lin_regression = function(x, y, family = "binomial"){
+  glm_fit = suppressWarnings(glm(y~x, family = family))
+  new_x = seq(min(x), max(x), length.out = 100)
+  data.frame(
+    x = new_x,
+    fit = predict(glm_fit, newdata = data.frame(x = new_x), type = "response")
+  )
+}
+
+msdm_results = msdm_results_embedding_raw %>% 
+  pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value") %>% 
+  dplyr::select(-obs) %>% 
+  dplyr::mutate(
+    fold_eval = 1
+  ) %>% 
+  drop_na()
+
+# Performance table
+performance = ssdm_results %>% 
+  dplyr::select(-obs) %>% 
+  dplyr::filter(fold_eval == 1, species %in% msdm_results$species) %>%  # Only look at first fold
+  bind_rows(msdm_results) %>% 
+  ungroup() %>% 
+  dplyr::mutate(
+    value = case_when(
+      ((is.na(value) | is.nan(value)) & metric %in% c("auc", "f1", "accurracy", "precision", "recall")) ~ 0.5,
+      ((is.na(value) | is.nan(value)) & metric %in% c("kappa")) ~ 0,
+      .default = value
+    )
+  )
+
+focal_metrics = unique(performance$metric)
+```
+
+
+## Analysis
+
+### Number of records
+
+```{r number_of_records, echo = FALSE, message=FALSE, warnings=FALSE}
+plot_performance_over_frequency = function(df_plot, metric) {
+
+  df_plot = dplyr::filter(df_plot, metric == !!metric) 
+  
+  # Calculate regression lines for each model and metric combination
+  suppressWarnings({
+    regression_lines = df_plot %>%
+      group_by(model) %>%
+      group_modify( ~ asym_regression(.x$obs, .x$value))
+  })
+  
+  # Create base plot
+  plot <- plot_ly() %>%
+    layout(
+      title = "Model Performance vs. Number of observations",
+      xaxis = list(title = "Number of observations", type = "log"),
+      yaxis = list(title = metric),
+      legend = list(x = 1.1, y = 0.5), # Move legend to the right of the plot
+      margin = list(r = 150), # Add right margin to accommodate legend
+      hovermode = 'closest'
+    )
+  
+  # Points
+  for (model_name in unique(df_plot$model)) {
+    plot = plot %>%
+      add_markers(
+        data = filter(df_plot, model == model_name, metric %in% focal_metrics),
+        x = ~ obs,
+        y = ~ value,
+        color = model_name, # Set color to match legendgroup
+        legendgroup = model_name,
+        opacity = 0.6,
+        name = ~ model,
+        hoverinfo = 'text',
+        text = ~ paste(
+          "Species:", species, "<br>Observations:", obs, "<br>Value:", round(value, 3)
+        )
+      )
+  }
+  
+  # Add regression lines
+  for (model_name in unique(df_plot$model)) {
+    reg_data = dplyr::filter(regression_lines, model == model_name)
+    plot = plot %>%
+      add_lines(
+        data = reg_data,
+        x = ~ x,
+        y = ~ fit,
+        color = model_name, # Set color to match legendgroup
+        legendgroup = model_name,
+        name = paste(model_name, '(fit)'),
+        showlegend = FALSE
+      )
+  }
+  
+  return(plot)
+}
+
+df_plot = performance %>% dplyr::left_join(obs_count, by = "species") 
+  
+```
+
+::: panel-tabset
+#### AUC
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "auc")
+bslib::card(plot, full_screen = T)
+```
+
+#### F1
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "f1")
+bslib::card(plot, full_screen = T)
+```
+
+#### Cohen's kappa
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "kappa")
+bslib::card(plot, full_screen = T)
+```
+
+#### Accurracy
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "accuracy")
+bslib::card(plot, full_screen = T)
+```
+
+#### Precision
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "precision")
+bslib::card(plot, full_screen = T)
+```
+
+#### Recall
+
+```{r echo = FALSE}
+plot = plot_performance_over_frequency(df_plot, metric = "recall")
+bslib::card(plot, full_screen = T)
+```
+:::
diff --git a/R/_publish.yml b/R/_publish.yml
index 02920d1..c07bacd 100644
--- a/R/_publish.yml
+++ b/R/_publish.yml
@@ -2,3 +2,7 @@
   quarto-pub:
     - id: 309c45c3-1515-4985-a435-9ffa1888c5e7
       url: 'https://chrkoenig.quarto.pub/ssdm-performance-analysis-0a06'
+- source: 05_01_performance_analysis_carsten.qmd
+  quarto-pub:
+    - id: f9cafa63-bfd4-4c00-b401-09ec137bd3ce
+      url: 'https://chrkoenig.quarto.pub/sdm-performance-carsten'
diff --git a/R/utils.R b/R/utils.R
index a7436a5..1c9dc02 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -40,7 +40,7 @@ evaluate_model <- function(model, data) {
   
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    probs <- predict(model, data, type = "response")[,1]
+    probs <- predict(model, as.matrix(data), type = "response")[,1]
     preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
     probs <- predict(model, data, type = "prob")$P
diff --git a/renv.lock b/renv.lock
index ff0fc5a..7f13630 100644
--- a/renv.lock
+++ b/renv.lock
@@ -2,6 +2,26 @@
   "R": {
     "Version": "4.3.0",
     "Repositories": [
+      {
+        "Name": "BioCsoft",
+        "URL": "https://bioconductor.org/packages/3.18/bioc"
+      },
+      {
+        "Name": "BioCann",
+        "URL": "https://bioconductor.org/packages/3.18/data/annotation"
+      },
+      {
+        "Name": "BioCexp",
+        "URL": "https://bioconductor.org/packages/3.18/data/experiment"
+      },
+      {
+        "Name": "BioCworkflows",
+        "URL": "https://bioconductor.org/packages/3.18/workflows"
+      },
+      {
+        "Name": "BioCbooks",
+        "URL": "https://bioconductor.org/packages/3.18/books"
+      },
       {
         "Name": "CRAN",
         "URL": "https://cloud.r-project.org"
@@ -177,6 +197,20 @@
       ],
       "Hash": "f27411eb6d9c3dada5edd444b8416675"
     },
+    "RcppArmadillo": {
+      "Package": "RcppArmadillo",
+      "Version": "14.2.0-1",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "R",
+        "Rcpp",
+        "methods",
+        "stats",
+        "utils"
+      ],
+      "Hash": "ce9c35ab9ea9c6ef4e27a08868cdb32b"
+    },
     "RcppEigen": {
       "Package": "RcppEigen",
       "Version": "0.3.4.0.2",
@@ -190,6 +224,57 @@
       ],
       "Hash": "4ac8e423216b8b70cb9653d1b3f71eb9"
     },
+    "RcppGSL": {
+      "Package": "RcppGSL",
+      "Version": "0.3.13",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "Rcpp",
+        "stats"
+      ],
+      "Hash": "e8fc7310d256a7b6c4de8e57ab76c55d"
+    },
+    "RcppParallel": {
+      "Package": "RcppParallel",
+      "Version": "5.1.9",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "R"
+      ],
+      "Hash": "f38a72a419b91faac0ce5d9eee04c120"
+    },
+    "RcppZiggurat": {
+      "Package": "RcppZiggurat",
+      "Version": "0.1.6",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "R",
+        "Rcpp",
+        "RcppGSL",
+        "graphics",
+        "parallel",
+        "stats",
+        "utils"
+      ],
+      "Hash": "75b4a36aeeed440ad03b996081190703"
+    },
+    "Rfast": {
+      "Package": "Rfast",
+      "Version": "2.1.0",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "R",
+        "Rcpp",
+        "RcppArmadillo",
+        "RcppParallel",
+        "RcppZiggurat"
+      ],
+      "Hash": "79e8394e1fa06a4ae954b70ca9b16e8f"
+    },
     "SQUAREM": {
       "Package": "SQUAREM",
       "Version": "2021.1",
@@ -832,6 +917,21 @@
       ],
       "Hash": "33698c4b3127fc9f506654607fb73676"
     },
+    "dismo": {
+      "Package": "dismo",
+      "Version": "1.3-16",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "R",
+        "Rcpp",
+        "methods",
+        "raster",
+        "sp",
+        "terra"
+      ],
+      "Hash": "5c8646b40ba69146afc070aaea9f893c"
+    },
     "doParallel": {
       "Package": "doParallel",
       "Version": "1.0.17",
@@ -1139,6 +1239,17 @@
       ],
       "Hash": "15e9634c0fcd294799e9b2e929ed1b86"
     },
+    "geos": {
+      "Package": "geos",
+      "Version": "0.2.4",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "libgeos",
+        "wk"
+      ],
+      "Hash": "117a3a09b793abf1d2146027a71b7524"
+    },
     "geosphere": {
       "Package": "geosphere",
       "Version": "1.5-18",
@@ -1623,6 +1734,13 @@
       ],
       "Hash": "d908914ae53b04d4c0c0fd72ecc35370"
     },
+    "libgeos": {
+      "Package": "libgeos",
+      "Version": "3.11.1-2",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Hash": "323d0f39c2e5ebcb152b810d1e8ed9bb"
+    },
     "lifecycle": {
       "Package": "lifecycle",
       "Version": "1.0.4",
@@ -2660,6 +2778,18 @@
       ],
       "Hash": "75940133cca2e339afce15a586f85b11"
     },
+    "spatialEco": {
+      "Package": "spatialEco",
+      "Version": "2.0-2",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "R",
+        "sf",
+        "terra"
+      ],
+      "Hash": "d0352030add66f0c73ff5d7473b7aef1"
+    },
     "stringdist": {
       "Package": "stringdist",
       "Version": "0.9.12",
@@ -2847,6 +2977,27 @@
       ],
       "Hash": "829f27b9c4919c16b593794a6344d6c0"
     },
+    "tidyterra": {
+      "Package": "tidyterra",
+      "Version": "0.6.1",
+      "Source": "Repository",
+      "Repository": "CRAN",
+      "Requirements": [
+        "R",
+        "cli",
+        "data.table",
+        "dplyr",
+        "ggplot2",
+        "magrittr",
+        "rlang",
+        "scales",
+        "sf",
+        "terra",
+        "tibble",
+        "tidyr"
+      ],
+      "Hash": "5d9672bc78297d33c4ad8b51a8aaa48f"
+    },
     "tidyverse": {
       "Package": "tidyverse",
       "Version": "2.0.0",
-- 
GitLab