diff --git a/R/03_02_absence_preparation.R b/R/03_02_absence_preparation.R
index b148c644cd3c799537e98bd6365449038a2db921..03865cb12614e49b881a864410b068e7263f6db9 100644
--- a/R/03_02_absence_preparation.R
+++ b/R/03_02_absence_preparation.R
@@ -199,4 +199,4 @@ files = list.files("data/r_objects/pa_sampling/", full.names = T)
 model_data = lapply(files, function(f){load(f); return(pa_spec)}) %>% 
   bind_rows()
 
-save(model_data, file = "data/r_objects/model_data_pa_sampling.RData")
+save(model_data, file = "data/r_objects/model_data.RData")
diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
index 914c1aa2990cacea42993a6cf5775d6954dc51b3..661e169f88482ecc0555d822aed62c89f2d30bdb 100644
--- a/R/04_01_ssdm_modeling.R
+++ b/R/04_01_ssdm_modeling.R
@@ -1,208 +1,199 @@
 library(furrr)
-library(progressr)
 library(dplyr)
+library(tidyr)
+library(sf)
+library(caret)
 library(cito)
+library(pROC)
 
 source("R/utils.R")
 
+load("data/r_objects/model_data.RData")
+
 # ----------------------------------------------------------------------#
-# Define Training function                                           ####
+# Run training                                                       ####
 # ----------------------------------------------------------------------#
-train_models = function(pa_files){
-  future::plan("multisession", workers = 4)
-  p = progressor(along = pa_files)
-  
-  furrr::future_walk(
-    .x = pa_files, 
-    .options = furrr::furrr_options(
-      seed = 123,
-      packages = c("dplyr", "tidyr", "sf", "caret", "cito", "pROC"),
-      scheduling = FALSE  # make sure workers get constant stream of work
-    ),
-    .f = function(pa_file){
-      load(pa_file)
-      species = pa_spec$species[1]
-      p(sprintf("species=%s", species))
+data_split = model_data %>% 
+  dplyr::group_by(species) %>% 
+  dplyr::group_split()
+
+future::plan("multisession", workers = 16)
+
+furrr::future_walk(
+  .x = data_split, 
+  .options = furrr::furrr_options(
+    seed = 123,
+    packages = c("dplyr", "tidyr", "sf", "caret", "cito", "pROC"),
+    scheduling = FALSE  # make sure workers get constant stream of work
+  ),
+  .f = function(pa_spec){
+    species = pa_spec$species[1]
+    
+    if(all(is.na(pa_spec$fold_eval))){
+      warning("Too few samples")
+      return()
+    }
+    
+    # Define empty result for performance eval
+    na_performance = list(    
+      AUC = NA,
+      Accuracy = NA,
+      Kappa = NA,
+      Precision = NA,
+      Recall = NA,
+      F1 = NA
+    )
+    
+    # Create factor presence column 
+    pa_spec$present_fct = factor(pa_spec$present, levels = c("0", "1"), labels = c("A", "P"))
+    
+    # Outer CV loop (for averaging performance metrics)
+    performance_cv = lapply(sort(unique(pa_spec$fold_eval)), function(k){
+      data_test = dplyr::filter(pa_spec, fold_eval == k)
+      data_train = dplyr::filter(pa_spec, fold_eval != k)
       
-      if(all(is.na(pa_spec$fold_eval))){
+      if(nrow(data_test) == 0 || nrow(data_train) == 0){
         warning("Too few samples")
-        return()
       }
       
-      # Define empty result for performance eval
-      na_performance = list(    
-        AUC = NA,
-        Accuracy = NA,
-        Kappa = NA,
-        Precision = NA,
-        Recall = NA,
-        F1 = NA
+      # Create inner CV folds for model training
+      cv_train = blockCV::cv_spatial(
+        data_train,
+        column = "present",
+        k = 5,
+        progress = F, plot = F, report = F
       )
+      data_train$fold_train = cv_train$folds_ids
       
-      # Create factor presence column 
-      pa_spec$present_fct = factor(pa_spec$present, levels = c("0", "1"), labels = c("A", "P"))
+      # Drop geometry
+      data_train$geometry = NULL
+      data_test$geometry = NULL
       
-      # Outer CV loop (for averaging performance metrics)
-      performance_cv = lapply(sort(unique(pa_spec$fold_eval)), function(k){
-        data_test = dplyr::filter(pa_spec, fold_eval == k)
-        data_train = dplyr::filter(pa_spec, fold_eval != k)
-        
-        if(nrow(data_test) == 0 || nrow(data_train) == 0){
-          warning("Too few samples")
-        }
-        
-        # Create inner CV folds for model training
-        cv_train = blockCV::cv_spatial(
-          data_train,
-          column = "present",
-          k = 5,
-          progress = F, plot = F, report = F
+      # Define caret training routine #####
+      index_train = lapply(unique(sort(data_train$fold_train)), function(x){
+        return(which(data_train$fold_train != x))
+      })
+      
+      train_ctrl = caret::trainControl(
+        search = "random",
+        classProbs = TRUE, 
+        index = index_train,
+        summaryFunction = caret::twoClassSummary, 
+        savePredictions = "final",
+      )
+      
+      # Define predictors
+      predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
+      
+      # Random Forest #####
+      rf_performance = tryCatch({
+        # Fit model
+        rf_fit = caret::train(
+          x = data_train[, predictors],
+          y = data_train$present_fct,
+          method = "rf",
+          trControl = train_ctrl,
+          tuneLength = 4,
+          verbose = F
         )
-        data_train$fold_train = cv_train$folds_ids
-        
-        # Drop geometry
-        data_train$geometry = NULL
-        data_test$geometry = NULL
-        
-        # Define caret training routine #####
-        index_train = lapply(unique(sort(data_train$fold_train)), function(x){
-          return(which(data_train$fold_train != x))
-        })
         
-        train_ctrl = trainControl(
-          search = "random",
-          classProbs = TRUE, 
-          index = index_train,
-          summaryFunction = twoClassSummary, 
-          savePredictions = "final",
+        evaluate_model(rf_fit, data_test)
+      }, error = function(e){
+        na_performance
+      })
+      
+      # Gradient Boosted Machine ####
+      gbm_performance = tryCatch({
+        gbm_fit = caret::train(  # namespaced like the RF fit and trainControl in this worker
+          x = data_train[, predictors],
+          y = data_train$present_fct,
+          method = "gbm",
+          trControl = train_ctrl,
+          tuneLength = 4,
+          verbose = F
         )
+        evaluate_model(gbm_fit, data_test)
+      }, error = function(e){
+        na_performance
+      })
+      
+      # Generalized Additive Model ####
+      gam_performance = tryCatch({
+        gam_fit = caret::train(  # namespaced like the RF fit and trainControl in this worker
+          x = data_train[, predictors],
+          y = data_train$present_fct,
+          method = "gamSpline",
+          tuneLength = 4,
+          trControl = train_ctrl
+        )
+        evaluate_model(gam_fit, data_test)
+      }, error = function(e){
+        na_performance
+      })
+      
+      # Neural Network ####
+      nn_performance = tryCatch({
+        formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+')))
         
-        # Define predictors
-        predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
-        
-        # Random Forest #####
-        rf_performance = tryCatch({
-          # Fit model
-          rf_fit = caret::train(
-            x = data_train[, predictors],
-            y = data_train$present_fct,
-            method = "rf",
-            trControl = train_ctrl,
-            tuneLength = 4,
-            verbose = F
-          )
-          
-          evaluate_model(rf_fit, data_test)
-        }, error = function(e){
-          na_performance
-        })
-        
-        # Gradient Boosted Machine ####
-        gbm_performance = tryCatch({
-          gbm_fit = train(
-            x = data_train[, predictors],
-            y = data_train$present_fct,
-            method = "gbm",
-            trControl = train_ctrl,
-            tuneLength = 4,
-            verbose = F
-          )
-          evaluate_model(gbm_fit, data_test)
-        }, error = function(e){
-          na_performance
-        })
-        
-        # Generalized Additive Model ####
-        gam_performance = tryCatch({
-          gam_fit = train(
-            x = data_train[, predictors],
-            y = data_train$present_fct,
-            method = "gamSpline",
-            tuneLength = 4,
-            trControl = train_ctrl
-          )
-          evaluate_model(gam_fit, data_test)
-        }, error = function(e){
-          na_performance
-        })
+        nn_fit = dnn(
+          formula,
+          data = data_train,
+          hidden = c(100L, 100L, 100L),
+          loss = "binomial",
+          activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
+          epochs = 500L, 
+          burnin = 100L,
+          lr = 0.001,   
+          batchsize = max(nrow(data_test)/10, 32),  # NOTE(review): sized from data_test although the network is fit on data_train (the removed line used data_train) - confirm this is intended
+          lambda = 0.01,
+          dropout = 0.2,
+          optimizer = config_optimizer("adam", weight_decay = 0.001),
+          lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 50, factor = 0.7),
+          early_stopping = 100,
+          validation = 0.2,
+          device = "cuda",
+          verbose = F,
+          plot = F
+        )
         
-        # Neural Network ####
-        nn_performance = tryCatch({
-          formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+')))
-          
-          nn_fit = dnn(
-            formula,
-            data = data_train,
-            hidden = c(100L, 100L, 100L),
-            loss = "binomial",
-            activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
-            epochs = 500L, 
-            burnin = 100L,
-            lr = 0.001,   
-            batchsize = max(nrow(data_train)/10, 64),
-            dropout = 0.1,
-            optimizer = config_optimizer("adam", weight_decay = 0.001),
-            lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 50, factor = 0.7),
-            early_stopping = 100,
-            validation = 0.2,
-            device = "cuda",
-            verbose = F,
-            plot = F
-          )
-          
-          if(nn_fit$successfull == 1){
-            evaluate_model(nn_fit, data_test)  
-          } else {
-            na_performance
-          }
-        }, error = function(e){
+        if(nn_fit$successfull == 1){
+          evaluate_model(nn_fit, data_test)  
+        } else {
           na_performance
-        })
-        
-        # Summarize results
-        performance_summary = tibble(
-          species = !!species,
-          obs = nrow(data_train),
-          fold_eval = k,
-          model = c("RF", "GBM", "GAM", "NN"),
-          auc = c(rf_performance$AUC, gbm_performance$AUC, gam_performance$AUC, nn_performance$AUC),
-          accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, gam_performance$Accuracy, nn_performance$Accuracy),
-          kappa = c(rf_performance$Kappa, gbm_performance$Kappa, gam_performance$Kappa, nn_performance$Kappa),
-          precision = c(rf_performance$Precision, gbm_performance$Precision, gam_performance$Precision, nn_performance$Precision),
-          recall = c(rf_performance$Recall, gbm_performance$Recall, gam_performance$Recall, nn_performance$Recall),
-          f1 = c(rf_performance$F1, gbm_performance$F1, gam_performance$F1, nn_performance$F1)
-        ) %>% 
-          pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value")
-        
-        return(performance_summary)
+        }
+      }, error = function(e){
+        na_performance
       })
       
-      # Combine and save evaluation results
-      performance_spec = bind_rows(performance_cv)
-      save(performance_spec, file = paste0("data/r_objects/model_results/", species, ".RData"))
-    }
-  )
-}
-
-# ----------------------------------------------------------------------#
-# Run training                                                       ####
-# ----------------------------------------------------------------------#
-handlers(global = TRUE)
-handlers("progress")
-
-pa_files = list.files("data/r_objects/pa_sampling/", full.names = T)
-species_processed = stringr::str_remove_all(list.files("data/r_objects/model_results/"), ".RData")
-pa_files_run = pa_files[!sapply(pa_files, function(x){any(stringr::str_detect(x, species_processed))})]
-
-train_models(pa_files_run)
+      # Summarize results
+      performance_summary = tibble(
+        species = !!species,
+        obs = nrow(data_train),
+        fold_eval = k,
+        model = c("RF", "GBM", "GAM", "NN"),
+        auc = c(rf_performance$AUC, gbm_performance$AUC, gam_performance$AUC, nn_performance$AUC),
+        accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, gam_performance$Accuracy, nn_performance$Accuracy),
+        kappa = c(rf_performance$Kappa, gbm_performance$Kappa, gam_performance$Kappa, nn_performance$Kappa),
+        precision = c(rf_performance$Precision, gbm_performance$Precision, gam_performance$Precision, nn_performance$Precision),
+        recall = c(rf_performance$Recall, gbm_performance$Recall, gam_performance$Recall, nn_performance$Recall),
+        f1 = c(rf_performance$F1, gbm_performance$F1, gam_performance$F1, nn_performance$F1)
+      ) %>% 
+        tidyr::pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value")
+      
+      return(performance_summary)
+    })
+    
+    # Combine and save evaluation results
+    performance_spec = bind_rows(performance_cv)
+    save(performance_spec, file = paste0("data/r_objects/ssdm_results/", species, ".RData"))
+  }
+)
 
 # ----------------------------------------------------------------------#
 # Combine results                                                    ####
 # ----------------------------------------------------------------------#
-files = list.files("data/r_objects/model_results/", full.names = T)
+files = list.files("data/r_objects/ssdm_results/", full.names = T)
 ssdm_results = lapply(files, function(f){load(f); return(performance_spec)}) %>% 
   bind_rows() 
-  #dplyr::group_by(species, model, metric) %>% 
-  #dplyr::summarize(value = mean(value, na.rm =T))
 
 save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
diff --git a/R/04_03_msdm_embed_raw.R b/R/04_03_msdm_embed_raw.R
index 6b2df020f31d9ad7f1a666bebd1e6675ff230a9f..0775329909439d7f3f6292fe38e02a901b3f0241 100644
--- a/R/04_03_msdm_embed_raw.R
+++ b/R/04_03_msdm_embed_raw.R
@@ -10,10 +10,10 @@ model_data = model_data %>%
   dplyr::mutate(species_int = as.integer(as.factor(model_data$species))) %>% 
   sf::st_drop_geometry()
   
-test_data = dplyr::filter(model_data, fold_eval == 1) %>% 
-  dplyr::select(-fold_eval)
-train_data = dplyr::filter(model_data, fold_eval != 1) %>% 
-  dplyr::select(-fold_eval)
+fold = 1
+
+test_data = dplyr::filter(model_data, fold_eval == fold)
+train_data = dplyr::filter(model_data, fold_eval != fold)
 
 # ----------------------------------------------------------------------#
 # Train model                                                        ####
@@ -25,15 +25,16 @@ plot(1, type="n", xlab="", ylab="", xlim=c(0, 25000), ylim=c(0, 0.7)) # empty pl
 msdm_fit_embedding_raw = dnn(
   formula,
   data = train_data,
-  hidden = c(500L, 500L, 500L),
+  hidden = c(250L, 250L, 250L),
   loss = "binomial",
   activation = c("sigmoid", "leaky_relu", "leaky_relu"),
   epochs = 10L, 
-  lr = 0.01,   
+  lr = 0.1,   
   baseloss = 1,
-  batchsize = nrow(train_data),
+  batchsize = 4096,
   dropout = 0.1,
   burnin = 500,
+  lambda = 0.0001,
   optimizer = config_optimizer("adam", weight_decay = 0.001),
   lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 150, factor = 0.7),
   early_stopping = 400,
@@ -41,7 +42,7 @@ msdm_fit_embedding_raw = dnn(
   device = "cuda"
 )
 
-save(msdm_fit_embedding_raw, file = "data/r_objects/msdm_fit_embedding_raw.RData")
+save(msdm_fit_embedding_raw, file = paste0("data/r_objects/msdm_raw_results/msdm_raw_fold", fold,".RData"))  # NOTE(review): the evaluation section still appears to load "data/r_objects/msdm_fit_embedding_raw.RData" (old path) - confirm it should read this per-fold file instead
 
 # ----------------------------------------------------------------------#
 # Evaluate model                                                     ####
@@ -51,28 +52,30 @@ load("data/r_objects/msdm_fit_embedding_raw.RData")
 data_split = test_data %>% 
   split(test_data$species)
 
-msdm_results_embedding_raw = lapply(data_split, function(data_spec){
-  species = data_spec$species[1]
-  data_spec = data_spec %>% 
-    dplyr::select(-species)
+msdm_results_embedding_raw = lapply(data_split, function(pa_spec){
+  species = pa_spec$species[1]
+  pa_spec = pa_spec %>% 
+    dplyr::select(-species, -fold_eval)
   
   msdm_performance = tryCatch({
-    evaluate_model(msdm_fit_embedding_raw, data_spec)
+    evaluate_model(msdm_fit_embedding_raw, pa_spec)
   }, error = function(e){
     list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
   })
   
   performance_summary = tibble(
     species = !!species,
-    obs = nrow(data_spec),
-    model = "MSDM_embed",
+    obs = nrow(dplyr::filter(train_data, species == !!species)),
+    fold_eval = !!fold,
+    model = "MSDM_embed_raw",
     auc = msdm_performance$AUC,
     accuracy = msdm_performance$Accuracy,
     kappa = msdm_performance$Kappa,
     precision = msdm_performance$Precision,
     recall = msdm_performance$Recall,
     f1 = msdm_performance$F1
-  )
+  ) %>% 
+    tidyr::pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value")
 }) %>% bind_rows()
 
-save(msdm_results_embedding_raw, file = "data/r_objects/msdm_results_embedding_raw.RData")
+save(msdm_results_embedding_raw, file = paste0("data/r_objects/msdm_raw_results/result_msdm_raw_fold", fold,".RData"))
diff --git a/R/utils.R b/R/utils.R
index 1c9dc02ef9d965b3ed9a02f7552aec85de46141f..265feed61ea02923661fa8b0225fa5c39b35e929 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -40,17 +40,18 @@ evaluate_model <- function(model, data) {
   
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    probs <- predict(model, as.matrix(data), type = "response")[,1]
-    preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
+    data = dplyr::select(data, any_of(all.vars(model$old_formula)))  # take columns from the model being evaluated, not the global msdm_fit_embedding_raw
+    probs = predict(model, as.matrix(data), type = "response")[,1]
+    preds = factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
-    probs <- predict(model, data, type = "prob")$P
-    preds <- predict(model, data, type = "raw")
+    probs = predict(model, data, type = "prob")$P
+    preds = predict(model, data, type = "raw")
   }
   
-  actual <- factor(data$present, levels = c("0", "1"), labels = c("A", "P"))
+  actual = factor(data$present, levels = c("0", "1"), labels = c("A", "P"))
   
   # Calculate AUC
-  auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
+  auc = pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
   
   # Calculate confusion matrix
   cm <- caret::confusionMatrix(preds, actual, positive = "P")