diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
index ddef6965715f976e9bf90267a213daa1c7ec4585..efc6107302149c4b33aa8b0afc698e1e3fedcfd9 100644
--- a/R/04_01_ssdm_modeling.R
+++ b/R/04_01_ssdm_modeling.R
@@ -1,170 +1,194 @@
-library(dplyr)
-library(tidyr)
 library(furrr)
-library(caret)
-library(cito)
-library(pROC)
+library(progressr)
+library(dplyr)
 
 source("R/utils.R")
 
-load("data/r_objects/full_evaluation/model_data.RData")
-
-data_split = split(model_data, model_data$species)
-
-# Define empty result for performance eval
-performance = list(    
-  AUC = numeric(0),
-  Accuracy = numeric(0),
-  Kappa = numeric(0),
-  Precision = numeric(0),
-  Recall = numeric(0),
-  F1 = numeric(0)
-)
-
 # ----------------------------------------------------------------------#
-# Train models                                                       ####
+# Define Training function                                           ####
 # ----------------------------------------------------------------------#
-future::plan("multisession", workers = 8)
-ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(seed = 123), .f = function(data_spec){
-  data_spec$present_fct = factor(data_spec$present, levels = c("0", "1"), labels = c("A", "P"))
+train_models = function(data_split){
+  future::plan("multisession", workers = 16)
+  p = progressor(along = data_split)
   
-  # Outer CV loop
-  for(k in sort(unique(data_spec$fold_eval))){
-    data_test = dplyr::filter(data_spec, fold_eval == k)
-    data_train = dplyr::filter(data_spec, fold_eval != k)
-    
-    if(nrow(data_test) == 0 || nrow(data_train) == 0){
-      return()
-    }
-    
-    # Create inner CV folds for model training
-    cv_train = blockCV::cv_spatial(
-      data_train,
-      column = "present",
-      k = 5,
-      progress = F, plot = F, report = F
-    )
-    data_train$fold_train = cv_train$folds_ids
-    
-    # Define caret training routine #####
-    index_train = lapply(unique(sort(data_train$fold_train)), function(x){
-      return(which(data_train$fold_train != x))
-    })
-    
-    train_ctrl = trainControl(
-      search = "grid",
-      classProbs = TRUE, 
-      index = index_train,
-      summaryFunction = twoClassSummary, 
-      savePredictions = "final"
-    )
-  
-    # Define predictors
-    predictors = paste0("layer_", 1:19)
-    
-    # Random Forest #####
-    rf_performance = tryCatch({
-      rf_grid = expand.grid(
-        mtry = c(3,7,11,15,19)                # Number of randomly selected predictors
-      )
+  ssdm_results_list = furrr::future_map(
+    .x = data_split, 
+    .options = furrr::furrr_options(
+      seed = 123,
+      packages = c("dplyr", "tidyr", "sf", "caret", "cito", "pROC")
+    ),
+    .f = function(data_spec){
+      spec = data_spec$species[1]
+      p(sprintf("spec=%s", spec))
       
-      rf_fit = caret::train(
-        x = data_train[, predictors],
-        y = data_train$present_fct,
-        method = "rf",
-        metric = "ROC",
-        tuneGrid = rf_grid,
-        trControl = train_ctrl
-      )
-      evaluate_model(rf_fit, data_train)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Gradient Boosted Machine ####
-    gbm_performance = tryCatch({
-      gbm_grid <- expand.grid(
-        n.trees = c(100, 500, 1000, 1500),       # number of trees
-        interaction.depth = c(3, 5, 7),          # Maximum depth of each tree
-        shrinkage = c(0.01, 0.005, 0.001),       # Lower learning rates
-        n.minobsinnode = c(10, 20)               # Minimum number of observations in nodes
-      )
+      if(all(is.na(data_spec$fold_eval))){
+        warning("Too few samples")
+        return()
+      }
       
-      gbm_fit = train(
-        x = data_train[, predictors],
-        y = data_train$present_fct,
-        method = "gbm",
-        metric = "ROC",
-        verbose = F,
-        tuneGrid = gbm_grid,
-        trControl = train_ctrl
-      )
-      evaluate_model(gbm_fit, data_test)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Generalized additive Model ####
-    glm_performance = tryCatch({
-      glm_fit = train(
-        x = data_train[, predictors],
-        y = data_train$present_fct,
-        method = "glm",
-        family=binomial, 
-        metric = "ROC",
-        preProcess = c("center", "scale"),
-        trControl = train_ctrl
+      # Define empty result for performance eval
+      na_performance = list(    
+        AUC = NA,
+        Accuracy = NA,
+        Kappa = NA,
+        Precision = NA,
+        Recall = NA,
+        F1 = NA
       )
-      evaluate_model(glm_fit, data_test)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Neural Network ####
-    nn_performance = tryCatch({
-      predictors = paste0("layer_", 1:19)
-      formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+')))
       
-      nn_fit = dnn(
-        formula,
-        data = data_train,
-        hidden = c(500L, 500L, 500L),
-        loss = "binomial",
-        activation = c("sigmoid", "leaky_relu", "leaky_relu"),
-        epochs = 500L, 
-        lr = 0.001,   
-        baseloss = 1,
-        batchsize = nrow(data_train),
-        dropout = 0.1,
-        burnin = 100,
-        optimizer = config_optimizer("adam", weight_decay = 0.001),
-        lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
-        early_stopping = 250,
-        validation = 0.3,
-        device = "cuda"
-      )
+      # Create factor presence column 
+      data_spec$present_fct = factor(data_spec$present, levels = c("0", "1"), labels = c("A", "P"))
       
-      evaluate_model(nn_fit, data_train)
-    }, error = function(e){
-      na_performance
-    })
-    
-    # Summarize results
-    performance_summary = tibble(
-      species = data_train$species[1],
-      obs = nrow(data_train),
-      model = c("SSDM_RF", "SSDM_GBM", "SSDM_GLM", "SSDM_NN"),
-      auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC, nn_performance$AUC),
-      accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy, nn_performance$Accuracy),
-      kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa, nn_performance$Kappa),
-      precision = c(rf_performance$Precision, gbm_performance$Precision, glm_performance$Precision, nn_performance$Precision),
-      recall = c(rf_performance$Recall, gbm_performance$Recall, glm_performance$Recall, nn_performance$Recall),
-      f1 = c(rf_performance$F1, gbm_performance$F1, glm_performance$F1, nn_performance$F1)
-    )
-  }
-  return(performance_summary)
-})
+      # Outer CV loop (for averaging performance metrics)
+      performance_cv = lapply(sort(unique(data_spec$fold_eval)), function(k){
+        data_test = dplyr::filter(data_spec, fold_eval == k)
+        data_train = dplyr::filter(data_spec, fold_eval != k)
+        
+        if(nrow(data_test) == 0 || nrow(data_train) == 0){
+          warning("Too few samples")
+        }
+        
+        # Create inner CV folds for model training
+        cv_train = blockCV::cv_spatial(
+          data_train,
+          column = "present",
+          k = 5,
+          progress = F, plot = F, report = F
+        )
+        data_train$fold_train = cv_train$folds_ids
+        
+        # Drop geometry
+        data_train$geometry = NULL
+        data_test$geometry = NULL
+        
+        # Define caret training routine #####
+        index_train = lapply(unique(sort(data_train$fold_train)), function(x){
+          return(which(data_train$fold_train != x))
+        })
+        
+        train_ctrl = trainControl(
+          search = "random",
+          classProbs = TRUE, 
+          index = index_train,
+          summaryFunction = twoClassSummary, 
+          savePredictions = "final",
+        )
+        
+        # Define predictors
+        predictors = paste0("layer_", 1:19)
+        
+        # Random Forest #####
+        rf_performance = tryCatch({
+          # Fit model
+          rf_fit = caret::train(
+            x = data_train[, predictors],
+            y = data_train$present_fct,
+            method = "rf",
+            metric = "AUC",
+            trControl = train_ctrl,
+            tuneLength = 4,
+            verbose = F
+          )
+          
+          evaluate_model(rf_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Gradient Boosted Machine ####
+        gbm_performance = tryCatch({
+          gbm_fit = train(
+            x = data_train[, predictors],
+            y = data_train$present_fct,
+            method = "gbm",
+            metric = "AUC",
+            trControl = train_ctrl,
+            tuneLength = 4,
+            verbose = F
+          )
+          evaluate_model(gbm_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Generalized Additive Model ####
+        gam_performance = tryCatch({
+          gam_fit = train(
+            x = data_train[, predictors],
+            y = data_train$present_fct,
+            method = "gamSpline",
+            metric = "AUC",
+            tuneLength = 4,
+            trControl = train_ctrl
+          )
+          evaluate_model(gam_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Neural Network ####
+        nn_performance = tryCatch({
+          formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+')))
+          
+          nn_fit = dnn(
+            formula,
+            data = data_train,
+            hidden = c(100L, 100L, 100L),
+            loss = "binomial",
+            activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
+            epochs = 400L, 
+            burnin = 100L,
+            lr = 0.001,   
+            batchsize = max(nrow(data_train)/10, 64),
+            dropout = 0.1,
+            optimizer = config_optimizer("adam", weight_decay = 0.001),
+            lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 50, factor = 0.7),
+            early_stopping = 100,
+            validation = 0.2,
+            device = "cuda",
+            verbose = F,
+            plot = F
+          )
+          
+          evaluate_model(nn_fit, data_test)
+        }, error = function(e){
+          na_performance
+        })
+        
+        # Summarize results
+        performance_summary = tibble(
+          species = spec,
+          obs = nrow(data_train),
+          fold_eval = k,
+          model = c("RF", "GBM", "GAM", "NN"),
+          auc = c(rf_performance$AUC, gbm_performance$AUC, gam_performance$AUC, nn_performance$AUC),
+          accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, gam_performance$Accuracy, nn_performance$Accuracy),
+          kappa = c(rf_performance$Kappa, gbm_performance$Kappa, gam_performance$Kappa, nn_performance$Kappa),
+          precision = c(rf_performance$Precision, gbm_performance$Precision, gam_performance$Precision, nn_performance$Precision),
+          recall = c(rf_performance$Recall, gbm_performance$Recall, gam_performance$Recall, nn_performance$Recall),
+          f1 = c(rf_performance$F1, gbm_performance$F1, gam_performance$F1, nn_performance$F1)
+        ) %>% 
+          pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value")
+        
+        return(performance_summary)
+      })
+      
+      # Combine and save evaluation results
+      performance_spec = bind_rows(performance_cv)
+      filename = paste0(gsub(" ", "_", spec), "_results.RData")
+      save(performance_spec, file = file.path("data/r_objects/nested_cv/model_results/", filename))
+    }
+  )
+}
+
+# ----------------------------------------------------------------------#
+# Run training                                                       ####
+# ----------------------------------------------------------------------#
+load("data/r_objects/nested_cv/model_data.RData")
+
+handlers(global = TRUE)
+handlers("progress")
 
-ssdm_results = bind_rows(ssdm_results)
+data_split = group_split(model_data, species)
 
-save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
+train_models(data_split)
\ No newline at end of file
diff --git a/R/05_02_MSDM_comparison.qmd b/R/05_02_MSDM_comparison.qmd
index 5b14c8a7d1241af5315de31327fbd91dc067e89b..b70af796777e659e9da4f3b62acda73a5ad89c2c 100644
--- a/R/05_02_MSDM_comparison.qmd
+++ b/R/05_02_MSDM_comparison.qmd
@@ -12,13 +12,15 @@ library(plotly)
 library(DT)
 library(shiny)
 
-load("../data/r_objects/msdm_results_embedding_raw.RData")
-load("../data/r_objects/msdm_results_embedding_traits_static.RData")
-load("../data/r_objects/msdm_results_embedding_traits_trained.RData")
-load("../data/r_objects/msdm_results_embedding_phylo_static.RData")
-load("../data/r_objects/msdm_results_embedding_phylo_trained.RData")
-load("../data/r_objects/msdm_results_embedding_range_static.RData")
-load("../data/r_objects/msdm_results_embedding_range_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_raw.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_traits_static.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_traits_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_phylo_static.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_phylo_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_range_static.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_range_trained.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_multi_nolonlat.RData")
+load("../data/r_objects/simple_cv/msdm_results_embedding_multi_lonlat.RData")
 
 sf::sf_use_s2(use_s2 = FALSE)
 ```
@@ -40,7 +42,9 @@ results_embedding_informed = c(
   "msdm_results_embedding_range_static",
   "msdm_results_embedding_range_trained",
   "msdm_results_embedding_traits_static",
-  "msdm_results_embedding_traits_trained"
+  "msdm_results_embedding_traits_trained",
+  "msdm_results_embedding_multi_nolonlat",
+  "msdm_results_embedding_multi_lonlat"
 )
 
 results_embedding_informed_merged = lapply(results_embedding_informed, function(df_name){
@@ -63,7 +67,9 @@ results_final = results_embedding_raw %>%
       "MSDM_embed_informed_traits_static" = "M_TS",
       "MSDM_embed_informed_traits_trained" = "M_TT",
       "MSDM_embed_informed_range_static" = "M_RS",
-      "MSDM_embed_informed_range_trained" = "M_RT"
+      "MSDM_embed_informed_range_trained" = "M_RT",
+      "MSDM_embed_informed_multi_nolonlat" = "M_MT_nolonlat",
+      "MSDM_embed_informed_multi_lonlat" = "M_MT_lonlat"
     ),
     across(all_of(focal_metrics), round, 3)
   )
@@ -274,13 +280,12 @@ for(model_name in unique(df_plot$model)){
 bslib::card(plot, full_screen = T)
 ```
 
-
 ## *Relative Performance*
 
 ```{r delta, echo = FALSE, message=FALSE, warnings=FALSE}
 results_ranked_obs = results_final_long %>% 
   group_by(species,  metric) %>% 
-  mutate(rank = rev(rank(value)))
+  mutate(rank = rank(value))
 
 reglines = results_ranked_obs %>%
   group_by(model, metric) %>%
@@ -426,7 +431,6 @@ bslib::card(plot, full_screen = T)
 ```
 :::
 
-
 ## *Trait space*
 
 ```{r trait_pca, echo = FALSE, message=FALSE, warnings=FALSE}
diff --git a/R/utils.R b/R/utils.R
index 1c9dc02ef9d965b3ed9a02f7552aec85de46141f..a7436a59c47617a8581a216a2306c42d8f5eb081 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -40,7 +40,7 @@ evaluate_model <- function(model, data) {
   
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    probs <- predict(model, as.matrix(data), type = "response")[,1]
+    probs <- predict(model, data, type = "response")[,1]
     preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
     probs <- predict(model, data, type = "prob")$P