From e0cabef305860b292268e3b3594ed483b0505b7e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Mon, 17 Feb 2025 12:53:30 +0100
Subject: [PATCH] updated/implemented different msdm approaches

---
 R/04_01_modelling_ssdm.R                | 138 ++++++-------
 R/04_02_modelling_msdm_embed.R          |   8 +-
 R/04_03_modelling_msdm_onehot.R         | 105 ++++++++++
 R/04_04_modelling_msdm_embed_informed.R | 115 +++++++++++
 R/04_05_modelling_msdm_rf.R             | 150 ++++++++++++++
 R/05_01_performance_report.qmd          | 103 +++++++---
 R/05_02_publication_analysis.R          | 251 +++++++++++++++++-------
 R/_publish.yml                          |   4 +
 snippets.R                              | 160 ---------------
 9 files changed, 698 insertions(+), 336 deletions(-)
 create mode 100644 R/04_03_modelling_msdm_onehot.R
 create mode 100644 R/04_04_modelling_msdm_embed_informed.R
 create mode 100644 R/04_05_modelling_msdm_rf.R
 delete mode 100644 snippets.R

diff --git a/R/04_01_modelling_ssdm.R b/R/04_01_modelling_ssdm.R
index c336fe3..e5240db 100644
--- a/R/04_01_modelling_ssdm.R
+++ b/R/04_01_modelling_ssdm.R
@@ -1,4 +1,3 @@
-library(furrr)
 library(dplyr)
 library(tidyr)
 library(sf)
@@ -13,48 +12,33 @@ load("data/r_objects/model_data.RData")
 # ----------------------------------------------------------------------#
 # Run training                                                       ####
 # ----------------------------------------------------------------------#
-species_processed = list.files("data/r_objects/ssdm_results/", pattern = "RData") %>% 
+species_processed = list.files("data/r_objects/ssdm_results/performance/", pattern = "RData") %>% 
   stringr::str_remove(".RData")
 
 data_split = model_data %>% 
-  dplyr::filter(!species %in% species_processed) %>% 
+  dplyr::filter(!is.na(fold_eval) & !species %in% species_processed) %>% 
   dplyr::group_by(species) %>% 
   dplyr::group_split()
 
-
 for(pa_spec in data_split){
   species = pa_spec$species[1]
   print(species)
-  
-  if(all(is.na(pa_spec$fold_eval))){
-    print("Too few samples")
-    next
-  }
-  
+
   # Define empty result for performance eval
-  na_performance = list(    
-    AUC = NA,
-    Accuracy = NA,
-    Kappa = NA,
-    Precision = NA,
-    Recall = NA,
-    F1 = NA
-  )
+  na_performance = list(AUC = NA_real_, Accuracy = NA_real_, Kappa = NA_real_, 
+                        Precision = NA_real_, Recall = NA_real_, F1 = NA_real_, 
+                        TP = NA_real_, FP = NA_real_, TN = NA_real_, FN = NA_real_)
   
   # Create factor presence column 
   pa_spec$present_fct = factor(pa_spec$present, levels = c("0", "1"), labels = c("A", "P"))
   
   # Outer CV loop (for averaging performance metrics)
-  performance_cv = lapply(sort(unique(pa_spec$fold_eval)), function(k){
-    print(paste("Fold", k))
+  performance_cv = lapply(1:5, function(fold){
+    print(paste("Fold", fold))
     
     ## Preparations #####
-    data_test = dplyr::filter(pa_spec, fold_eval == k)
-    data_train = dplyr::filter(pa_spec, fold_eval != k)
-    
-    if(nrow(data_test) == 0 || nrow(data_train) == 0){
-      warning("Too few samples")
-    }
+    data_test = dplyr::filter(pa_spec, fold_eval == fold)
+    data_train = dplyr::filter(pa_spec, fold_eval != fold)
     
     # Create inner CV folds for model training
     cv_train = blockCV::cv_spatial(
@@ -79,7 +63,7 @@ for(pa_spec in data_split){
       classProbs = TRUE, 
       index = index_train,
       summaryFunction = caret::twoClassSummary, 
-      savePredictions = "final",
+      savePredictions = "final"
     )
     
     # Define predictors
@@ -93,7 +77,7 @@ for(pa_spec in data_split){
         y = data_train$present_fct,
         method = "rf",
         trControl = train_ctrl,
-        tuneLength = 4,
+        tuneLength = 8,
         verbose = F
       )
       
@@ -103,16 +87,16 @@ for(pa_spec in data_split){
     })
     
     ## Gradient Boosted Machine ####
-    gbm_performance = tryCatch({
-      gbm_fit = train(
+    xgb_performance = tryCatch({
+      xgb_fit = train(
         x = data_train[, predictors],
         y = data_train$present_fct,
-        method = "gbm",
+        method = "xgbTree",
         trControl = train_ctrl,
-        tuneLength = 4,
+        tuneLength = 8,
         verbose = F
       )
-      evaluate_model(gbm_fit, data_test)
+      evaluate_model(xgb_fit, data_test)
     }, error = function(e){
       na_performance
     })
@@ -123,7 +107,7 @@ for(pa_spec in data_split){
         x = data_train[, predictors],
         y = data_train$present_fct,
         method = "gamSpline",
-        tuneLength = 4,
+        tuneLength = 8,
         trControl = train_ctrl
       )
       evaluate_model(gam_fit, data_test)
@@ -139,16 +123,13 @@ for(pa_spec in data_split){
         data = data_train,
         hidden = c(200L, 200L, 200L),
         loss = "binomial",
-        activation = c("sigmoid", "leaky_relu", "leaky_relu"),
         epochs = 200L, 
-        burnin = 100L,
+        burnin = 200L,
+        early_stopping = 30,
         lr = 0.001,   
-        batchsize = max(nrow(data_test)/10, 64),
-        lambda = 0.0001,
-        dropout = 0.2,
-        optimizer = config_optimizer("adam", weight_decay = 0.001),
-        lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 50, factor = 0.7),
-        early_stopping = 100,
+        batchsize = min(ceiling(nrow(data_train)/10), 64),
+        dropout = 0.25,
+        optimizer = config_optimizer("adam"),
         validation = 0.2,
         device = "cuda",
         verbose = F,
@@ -164,16 +145,20 @@ for(pa_spec in data_split){
     performance_summary = tibble(
       species = !!species,
       obs = nrow(data_train),
-      fold_eval = k,
-      model = c("RF", "GBM", "GAM", "NN"),
-      auc = c(rf_performance$AUC, gbm_performance$AUC, gam_performance$AUC, nn_performance$AUC),
-      accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, gam_performance$Accuracy, nn_performance$Accuracy),
-      kappa = c(rf_performance$Kappa, gbm_performance$Kappa, gam_performance$Kappa, nn_performance$Kappa),
-      precision = c(rf_performance$Precision, gbm_performance$Precision, gam_performance$Precision, nn_performance$Precision),
-      recall = c(rf_performance$Recall, gbm_performance$Recall, gam_performance$Recall, nn_performance$Recall),
-      f1 = c(rf_performance$F1, gbm_performance$F1, gam_performance$F1, nn_performance$F1)
+      fold_eval = fold,
+      model = c("RF", "XGB", "GAM", "NN"),
+      AUC = c(rf_performance$AUC, xgb_performance$AUC, gam_performance$AUC, nn_performance$AUC),
+      Accuracy = c(rf_performance$Accuracy, xgb_performance$Accuracy, gam_performance$Accuracy, nn_performance$Accuracy),
+      Kappa = c(rf_performance$Kappa, xgb_performance$Kappa, gam_performance$Kappa, nn_performance$Kappa),
+      Precision = c(rf_performance$Precision, xgb_performance$Precision, gam_performance$Precision, nn_performance$Precision),
+      Recall = c(rf_performance$Recall, xgb_performance$Recall, gam_performance$Recall, nn_performance$Recall),
+      F1 = c(rf_performance$F1, xgb_performance$F1, gam_performance$F1, nn_performance$F1),
+      TP = c(rf_performance$TP, xgb_performance$TP, gam_performance$TP, nn_performance$TP),
+      FP = c(rf_performance$FP, xgb_performance$FP, gam_performance$FP, nn_performance$FP),
+      TN = c(rf_performance$TN, xgb_performance$TN, gam_performance$TN, nn_performance$TN),
+      FN = c(rf_performance$FN, xgb_performance$FN, gam_performance$FN, nn_performance$FN)
     ) %>% 
-      tidyr::pivot_longer(all_of(c("auc", "accuracy", "kappa", "precision", "recall", "f1")), names_to = "metric", values_to = "value") 
+      tidyr::pivot_longer(-any_of(c("species", "obs", "fold_eval", "model")), names_to = "metric", values_to = "value") 
     
     return(performance_summary)
   })
@@ -182,23 +167,22 @@ for(pa_spec in data_split){
   performance_spec = bind_rows(performance_cv) %>% 
     dplyr::arrange(fold_eval, model, metric)
   
-  save(performance_spec, file = paste0("data/r_objects/ssdm_results/", species, ".RData"))
+  save(performance_spec, file = paste0("data/r_objects/ssdm_results/performance/", species, ".RData"))
 }
 
-
 # Combine results  
-files = list.files("data/r_objects/ssdm_results/", full.names = T, pattern = ".RData")
+files = list.files("data/r_objects/ssdm_results/performance/", full.names = T, pattern = ".RData")
 ssdm_results = lapply(files, function(f){load(f); return(performance_spec)}) %>% 
   bind_rows() 
 
-save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
+save(ssdm_results, file = "data/r_objects/ssdm_performance.RData")
 
 # ----------------------------------------------------------------------#
 # Train full models                                                  ####
 # ----------------------------------------------------------------------#
 data_split = model_data %>% 
   dplyr::filter(!is.na(fold_eval)) %>% 
-  dplyr::mutate(present_fct = factor(present, levels = c("0", "1"), labels = c("A", "P"))) %>% 
+  dplyr::mutate(present_fct = factor(present, levels = c("0", "1"), labels = c("A", "P"))) %>%
   dplyr::group_by(species) %>% 
   dplyr::group_split()
 
@@ -217,7 +201,6 @@ for(pa_spec in data_split){
   
   # Drop geometry
   pa_spec$geometry = NULL
-  pa_spec$geometry = NULL
   
   # Define caret training routine 
   index_train = lapply(unique(sort(pa_spec$fold_train)), function(x){
@@ -231,33 +214,37 @@ for(pa_spec in data_split){
     summaryFunction = caret::twoClassSummary, 
     savePredictions = "final",
   )
-  
+
   # Define predictors
   predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
   
   # Fit models
   try({
+    # Fit model
     rf_fit = caret::train(
       x = pa_spec[, predictors],
       y = pa_spec$present_fct,
       method = "rf",
+      metric = "Kappa",
       trControl = train_ctrl,
-      tuneLength = 4,
+      tuneLength = 8,
       verbose = F
     )
-    save(rf_fit, file = paste0("data/r_objects/ssdm_results/models/", species, "_rf_fit.RData"))
+
+    save(rf_fit, file = paste0("data/r_objects/ssdm_results/full_models/", species, "_rf_fit.RData"))
   })
   
   try({
-    gbm_fit = train(
+    xgb_fit = train(
       x = pa_spec[, predictors],
       y = pa_spec$present_fct,
-      method = "gbm",
+      method = "xgbTree",
+      metric = "Kappa",
       trControl = train_ctrl,
-      tuneLength = 4,
+      tuneLength = 8,
       verbose = F
     )
-    save(gbm_fit, file = paste0("data/r_objects/ssdm_results/models/", species, "_gbm_fit.RData"))
+    save(xgb_fit, file = paste0("data/r_objects/ssdm_results/full_models/", species, "_xgb_fit.RData"))
   })
   
   try({
@@ -265,10 +252,11 @@ for(pa_spec in data_split){
       x = pa_spec[, predictors],
       y = pa_spec$present_fct,
       method = "gamSpline",
-      tuneLength = 4,
-      trControl = train_ctrl
+      metric = "Kappa",
+      trControl = train_ctrl,
+      tuneLength = 8,
     )
-    save(gam_fit, file = paste0("data/r_objects/ssdm_results/models/", species, "_gam_fit.RData"))
+    save(gam_fit, file = paste0("data/r_objects/ssdm_results/full_models/", species, "_gam_fit.RData"))
   })
   
   try({
@@ -278,21 +266,19 @@ for(pa_spec in data_split){
       data = pa_spec,
       hidden = c(200L, 200L, 200L),
       loss = "binomial",
-      activation = c("sigmoid", "leaky_relu", "leaky_relu"),
       epochs = 200L, 
-      burnin = 100L,
+      burnin = 200L,
+      early_stopping = 30,
       lr = 0.001,   
-      batchsize = max(nrow(pa_spec)/10, 64),
-      lambda = 0.0001,
-      dropout = 0.2,
-      optimizer = config_optimizer("adam", weight_decay = 0.001),
-      lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 50, factor = 0.7),
-      early_stopping = 100,
+      batchsize = min(ceiling(nrow(pa_spec)/10), 64),
+      dropout = 0.25,
+      optimizer = config_optimizer("adam"),
       validation = 0.2,
       device = "cuda",
       verbose = F,
       plot = F
     )
-    save(nn_fit, file = paste0("data/r_objects/ssdm_results/models/", species, "_nn_fit.RData"))
+    
+    save(nn_fit, file = paste0("data/r_objects/ssdm_results/full_models/", species, "_nn_fit.RData"))
   })
 }
diff --git a/R/04_02_modelling_msdm_embed.R b/R/04_02_modelling_msdm_embed.R
index dd991f5..584fb57 100644
--- a/R/04_02_modelling_msdm_embed.R
+++ b/R/04_02_modelling_msdm_embed.R
@@ -27,10 +27,10 @@ for(fold in 1:5){
   while(!converged){
     msdm_embed_fit = dnn(
       formula,
-      data = model_data,
+      data = train_data,
       hidden = c(200L, 200L, 200L),
       loss = "binomial",
-      epochs = 2500, 
+      epochs = 5000, 
       lr = 0.001,   
       batchsize = 4096,
       dropout = 0.25,
@@ -41,7 +41,7 @@ for(fold in 1:5){
       device = "cuda"
     )
     
-    if(min(msdm_embed_fit$losses$valid_l) < 0.4){
+    if(min(msdm_embed_fit$losses$valid_l, na.rm = T) < 0.4){
       converged = T
     }
   }
@@ -55,7 +55,7 @@ msdm_embed_fit = dnn(
   data = model_data,
   hidden = c(200L, 200L, 200L),
   loss = "binomial",
-  epochs = 2500, 
+  epochs = 7500, 
   lr = 0.001,   
   baseloss = 1,
   batchsize = 4096,
diff --git a/R/04_03_modelling_msdm_onehot.R b/R/04_03_modelling_msdm_onehot.R
new file mode 100644
index 0000000..71f5ac0
--- /dev/null
+++ b/R/04_03_modelling_msdm_onehot.R
@@ -0,0 +1,105 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data.RData")
+
+model_data = model_data %>% 
+  dplyr::filter(!is.na(fold_eval)) %>% 
+  dplyr::mutate(species = as.factor(species)) %>% 
+  sf::st_drop_geometry()
+
+# ----------------------------------------------------------------------#
+# Train model                                                        ####
+# ----------------------------------------------------------------------#
+formula = present ~ bio6 + bio17 + cmi + rsds + igfc + dtfw + igsw + roughness + species
+
+# 1. Cross validation
+for(fold in 1:5){
+  # Prepare data
+  train_data = dplyr::filter(model_data, fold_eval != fold)
+  
+  # Run model
+  converged = F
+  while(!converged){
+    msdm_onehot_fit = dnn(
+      formula,
+      data = train_data,
+      hidden = c(200L, 200L, 200L),
+      loss = "binomial",
+      epochs = 5000, 
+      lr = 0.001,   
+      batchsize = 4096,
+      dropout = 0.25,
+      burnin = 50,
+      optimizer = config_optimizer("adam"),
+      early_stopping = 200,
+      validation = 0.2,
+      device = "cuda"
+    )
+    
+    if(min(msdm_onehot_fit$losses$valid_l, na.rm = T) < 0.4){
+      converged = T
+    }
+  }
+  
+  save(msdm_onehot_fit, file = paste0("data/r_objects/msdm_onehot_results/msdm_onehot_fit_fold", fold,".RData"))
+}
+
+# Full model
+msdm_onehot_fit = dnn(
+  formula,
+  data = model_data,
+  hidden = c(200L, 200L, 200L),
+  loss = "binomial",
+  epochs = 7500, 
+  lr = 0.001,   
+  batchsize = 4096,
+  dropout = 0.25,
+  burnin = 500,
+  optimizer = config_optimizer("adam"),
+  early_stopping = 300,
+  validation = 0.2,
+  device = "cuda"
+)
+
+save(msdm_onehot_fit, file = paste0("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData"))
+
+# ----------------------------------------------------------------------#
+# Evaluate model                                                     ####
+# ----------------------------------------------------------------------#
+msdm_onehot_performance = lapply(1:5, function(fold){
+  load(paste0("data/r_objects/msdm_onehot_results/msdm_onehot_fit_fold", fold, ".RData"))
+  
+  test_data_split = model_data %>% 
+    dplyr::filter(fold_eval == fold) %>% 
+    dplyr::group_split(species)
+  
+  lapply(test_data_split, function(test_data_spec){
+    species = test_data_spec$species[1]
+    
+    performance = tryCatch({
+      evaluate_model(msdm_onehot_fit, test_data_spec)
+    }, error = function(e){
+      list(AUC = NA_real_, Accuracy = NA_real_, Kappa = NA_real_, 
+           Precision = NA_real_, Recall = NA_real_, F1 = NA_real_, 
+           TP = NA_real_, FP = NA_real_, TN = NA_real_, FN = NA_real_)
+    })
+    
+    performance_summary = performance %>% 
+      as_tibble() %>% 
+      mutate(
+        species = !!species,
+        obs = nrow(dplyr::filter(model_data, species == !!species, fold_eval != !!fold)),
+        fold_eval = !!fold,
+        model = "MSDM_onehot",
+      ) %>% 
+      tidyr::pivot_longer(-any_of(c("species", "obs", "fold_eval", "model")), names_to = "metric", values_to = "value")
+  }) %>% 
+    bind_rows()
+}) %>% 
+  bind_rows()
+
+save(msdm_onehot_performance, file = paste0("data/r_objects/msdm_onehot_performance.RData"))
diff --git a/R/04_04_modelling_msdm_embed_informed.R b/R/04_04_modelling_msdm_embed_informed.R
new file mode 100644
index 0000000..719f2ec
--- /dev/null
+++ b/R/04_04_modelling_msdm_embed_informed.R
@@ -0,0 +1,115 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data.RData")
+load("data/r_objects/func_dist.RData")
+
+model_species = intersect(model_data$species, colnames(func_dist)) 
+
+model_data = model_data %>% 
+  dplyr::filter(
+    !is.na(fold_eval),
+    species %in% !!model_species
+  ) %>% 
+  dplyr::mutate(species = as.factor(species)) %>% 
+  sf::st_drop_geometry()
+
+# ----------------------------------------------------------------------#
+# Train model                                                        ####
+# ----------------------------------------------------------------------#
+func_dist = func_dist[model_species, model_species]
+embeddings = eigen(func_dist)$vectors[,1:10]
+predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness")
+formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species, weights = embeddings)")) 
+
+# 1. Cross validation
+for(fold in 1:5){
+  # Prepare data
+  train_data = dplyr::filter(model_data, fold_eval != fold)
+  
+  # Run model
+  converged = F
+  while(!converged){
+    msdm_embed_traits_fit = dnn(
+      formula,
+      data = train_data,
+      hidden = c(200L, 200L, 200L),
+      loss = "binomial",
+      epochs = 5000, 
+      lr = 0.001,   
+      batchsize = 4096,
+      dropout = 0.25,
+      burnin = 50,
+      optimizer = config_optimizer("adam"),
+      early_stopping = 200,
+      validation = 0.2,
+      device = "cuda"
+    )
+    
+    if(min(msdm_embed_traits_fit$losses$valid_l, na.rm = T) < 0.4){
+      converged = T
+    }
+  }
+  
+  save(msdm_embed_traits_fit, file = paste0("data/r_objects/msdm_embed_traits_results/msdm_embed_traits_fit_fold", fold,".RData"))
+}
+
+# Full model
+msdm_embed_traits_fit = dnn(
+  formula,
+  data = model_data,
+  hidden = c(200L, 200L, 200L),
+  loss = "binomial",
+  epochs = 7500, 
+  lr = 0.001,   
+  baseloss = 1,
+  batchsize = 4096,
+  dropout = 0.25,
+  burnin = 500,
+  optimizer = config_optimizer("adam"),
+  early_stopping = 300,
+  validation = 0.2,
+  device = "cuda"
+)
+
+save(msdm_embed_traits_fit, file = paste0("data/r_objects/msdm_embed_traits_results/msdm_embed_traits_fit_full.RData"))
+
+# ----------------------------------------------------------------------#
+# Evaluate model                                                     ####
+# ----------------------------------------------------------------------#
+msdm_embed_traits_performance = lapply(1:5, function(fold){
+  load(paste0("data/r_objects/msdm_embed_traits_results/msdm_embed_traits_fit_fold", fold, ".RData"))
+  
+  test_data_split = model_data %>% 
+    dplyr::filter(fold_eval == fold) %>% 
+    dplyr::group_split(species)
+  
+  lapply(test_data_split, function(test_data_spec){
+    species = test_data_spec$species[1]
+    
+    performance = tryCatch({
+      evaluate_model(msdm_embed_traits_fit, test_data_spec)
+    }, error = function(e){
+      list(AUC = NA_real_, Accuracy = NA_real_, Kappa = NA_real_, 
+           Precision = NA_real_, Recall = NA_real_, F1 = NA_real_, 
+           TP = NA_real_, FP = NA_real_, TN = NA_real_, FN = NA_real_)
+    })
+    
+    performance_summary = performance %>% 
+      as_tibble() %>% 
+      mutate(
+        species = !!species,
+        obs = nrow(dplyr::filter(model_data, species == !!species, fold_eval != !!fold)),
+        fold_eval = !!fold,
+        model = "MSDM_embed_traits",
+      ) %>% 
+      tidyr::pivot_longer(-any_of(c("species", "obs", "fold_eval", "model")), names_to = "metric", values_to = "value")
+  }) %>% 
+    bind_rows()
+}) %>% 
+  bind_rows()
+
+save(msdm_embed_traits_performance, file = paste0("data/r_objects/msdm_embed_traits_performance.RData"))
diff --git a/R/04_05_modelling_msdm_rf.R b/R/04_05_modelling_msdm_rf.R
new file mode 100644
index 0000000..29819aa
--- /dev/null
+++ b/R/04_05_modelling_msdm_rf.R
@@ -0,0 +1,150 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data.RData")
+
+model_data = model_data %>% 
+  dplyr::filter(!is.na(fold_eval)) %>% 
+  dplyr::mutate(
+    species = as.factor(species),
+    present_fct = factor(present, levels = c("0", "1"), labels = c("A", "P"))
+  ) 
+
+# ----------------------------------------------------------------------#
+# Train model                                                        ####
+# ----------------------------------------------------------------------#
+# Define predictors
+predictors = c("bio6", "bio17", "cmi", "rsds", "igfc", "dtfw", "igsw", "roughness", "species")
+
+# Cross validation
+for(fold in 1:5){
+  ## Preparations #####
+  data_train = dplyr::filter(model_data, fold_eval != fold) %>% 
+    sf::st_drop_geometry()
+  
+  # Define caret training routine 
+  train_ctrl = caret::trainControl(
+    method = "cv",
+    number = 5,
+    classProbs = TRUE, 
+    summaryFunction = caret::twoClassSummary, 
+    savePredictions = "final"
+  )
+  
+  tune_grid = expand.grid(
+    mtry = c(2,4,6,8),
+    splitrule = "gini",
+    min.node.size = c(1,4,9,16)
+  )
+  
+  # Run model
+  rf_fit = caret::train(
+    x = data_train[, predictors],
+    y = data_train$present_fct,
+    method = "ranger",
+    metric = "Accuracy",
+    trControl = train_ctrl,
+    tuneGrid = tune_grid,
+    num.threads = 32
+  )
+  
+  save(rf_fit, file = paste0("data/r_objects/msdm_rf/msdm_rf_fit_fold", fold,".RData"))
+}
+
+# Full model
+# Define caret training routine 
+full_data = model_data %>% 
+  sf::st_drop_geometry()
+
+train_ctrl = caret::trainControl(
+  method = "cv",
+  number = 5,
+  classProbs = TRUE, 
+  summaryFunction = caret::twoClassSummary, 
+  savePredictions = "final"
+)
+
+tune_grid = expand.grid(
+  mtry = c(2,4,6,8),
+  splitrule = "gini",
+  min.node.size = c(1,4,9,16)
+)
+
+# Run model
+rf_fit = caret::train(
+  x = full_data[, predictors],
+  y = full_data$present_fct,
+  method = "ranger",
+  metric = "Accuracy",
+  trControl = train_ctrl,
+  tuneGrid = tune_grid
+)
+
+save(rf_fit, file = "data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
+
+# ----------------------------------------------------------------------#
+# Evaluate model                                                     ####
+# ----------------------------------------------------------------------#
+msdm_rf_performance = lapply(1:5, function(fold){
+  load(paste0("data/r_objects/msdm_rf/rf_fit_fold", fold, ".RData"))
+  
+  test_data =dplyr::filter(model_data, fold_eval == fold) %>% 
+    sf::st_drop_geometry()
+  
+  actual = factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
+  probs = predict(rf_fit, test_data, type = "prob")$P
+  preds = predict(rf_fit, test_data, type = "raw")
+  
+  eval_dfs = data.frame(
+    species = test_data$species,
+    actual,
+    probs,
+    preds
+  ) %>% 
+    group_by(species) %>% 
+    group_split()
+  
+  
+  lapply(eval_dfs, function(eval_df_spec){
+    species = eval_df_spec$species[1]
+    
+    performance = tryCatch({
+      auc = pROC::roc(eval_df_spec$actual, eval_df_spec$probs, levels = c("P", "A"), direction = ">")$auc
+      cm = caret::confusionMatrix(eval_df_spec$preds, eval_df_spec$actual, positive = "P")
+      
+      list(
+        AUC = as.numeric(auc),
+        Accuracy = cm$overall["Accuracy"],
+        Kappa = cm$overall["Kappa"],
+        Precision = cm$byClass["Precision"],
+        Recall = cm$byClass["Recall"],
+        F1 = cm$byClass["F1"],
+        TP = cm$table["P", "P"],
+        FP = cm$table["P", "A"],
+        TN = cm$table["A", "A"],
+        FN = cm$table["A", "P"]
+      )
+    }, error = function(e){
+      list(AUC = NA_real_, Accuracy = NA_real_, Kappa = NA_real_, 
+           Precision = NA_real_, Recall = NA_real_, F1 = NA_real_, 
+           TP = NA_real_, FP = NA_real_, TN = NA_real_, FN = NA_real_)
+    })
+    
+    performance_summary = performance %>% 
+      as_tibble() %>% 
+      mutate(
+        species = !!species,
+        obs = nrow(dplyr::filter(model_data, species == !!species, fold_eval != !!fold)),
+        fold_eval = !!fold,
+        model = "MSDM_rf",
+      ) %>% 
+      tidyr::pivot_longer(-any_of(c("species", "obs", "fold_eval", "model")), names_to = "metric", values_to = "value")
+  }) %>% 
+    bind_rows()
+}) %>% 
+  bind_rows()
+
+save(msdm_rf_performance, file = paste0("data/r_objects/msdm_rf_performance.RData"))
diff --git a/R/05_01_performance_report.qmd b/R/05_01_performance_report.qmd
index 03ed7d3..686a903 100644
--- a/R/05_01_performance_report.qmd
+++ b/R/05_01_performance_report.qmd
@@ -1,5 +1,5 @@
 ---
-title: "SDM Performance analysis"
+title: "SDM Performance report"
 format: html
 editor: visual
 engine: knitr
@@ -14,6 +14,8 @@ library(DT)
 load("../data/r_objects/model_data.RData")
 load("../data/r_objects/ssdm_results.RData")
 load("../data/r_objects/msdm_embed_performance.RData")
+load("../data/r_objects/msdm_onehot_performance.RData")
+load("../data/r_objects/msdm_rf_performance.RData")
 
 load("../data/r_objects/range_maps.RData")
 load("../data/r_objects/range_maps_gridded.RData")
@@ -54,6 +56,9 @@ lin_regression = function(x, y, family = "binomial"){
 # Performance table
 performance = ssdm_results %>% 
   bind_rows(msdm_embed_performance) %>% 
+  bind_rows(msdm_onehot_performance) %>% 
+  bind_rows(msdm_rf_performance) %>% 
+  dplyr::filter(fold_eval <= 3) %>% # TODO use all folds
   dplyr::group_by(species, model, metric) %>% 
   dplyr::summarise(value = mean(value, na.rm = F)) %>% 
   dplyr::mutate(
@@ -65,8 +70,7 @@ performance = ssdm_results %>%
     )
   )
 
-
-focal_metrics = unique(performance$metric)
+focal_metrics = c("accuracy", "auc", "f1", "kappa", "precision", "recall")
 
 plot_performance = function(df_plot, metric, x_var, x_label, x_log = T, reg_func = lin_regression) {
   df_plot = dplyr::filter(df_plot, metric == !!metric) %>% 
@@ -129,7 +133,7 @@ plot_performance = function(df_plot, metric, x_var, x_label, x_log = T, reg_func
 
 ## Summary
 
-This document summarizes the performance of different sSDM and mSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed along five potential influence factors (number of records, range size, range coverage, range coverage bias, and functional group). The comparison of sSDM vs mSDM approaches is of particular interest.
+This document summarizes the performance of different sSDM and mSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed along five potential influence factors (number of records, range size, range coverage, range coverage bias). The comparison of sSDM vs mSDM approaches is of particular interest.
 
 Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).
 
@@ -141,9 +145,6 @@ Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).
 -   Absence sampling from probability surface that balances geographical sampling bias and environmental preferences per species
     -   higher probability in areas that have been sampled more intensely
     -   lower probability in areas with environmental conditions similar to presence locations
-
-![absence sampling sampling](images/Akodon%20boliviensis.pdf){width="100%" height="800"}
-
 -   Eight predictors:
     -   Min Temperature of Coldest Month (bio6)
     -   Precipitation of Driest Quarter (bio17)
@@ -159,38 +160,27 @@ Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).
 
 #### sSDM Algorithms
 
--   Four Algorithms: Random Forest (RF), Gradient Boosting Machine (GBM), Generalized Additive Model (GAM), Neural Network (NN)
+-   Four algorithms: Random Forest (RF), Gradient Boosting Machine (GBM), Generalized Additive Model (GAM), Neural Network (NN)
 
 -   NN: Manual hyperparameter tuning, same settings across species
 
--   RF + GBM + GAM: Automated hyperparameter tuning (4 random combinations) per species
+-   RF + GBM + GAM: Automated hyperparameter tuning (8 random combinations) per species
 
 #### mSDM Algorithms
 
-Multispecies Neural Network with species embedding (**MSDM_embed**)
+-   Three algorithms: Random Forest (MSDM_rf), Neural Network (MSDM_embed, MSDM_onehot)
 
--   prediction: probability of occurrence given a set of (environmental) inputs and species identity
--   embedding initialized at random
--   three hidden layers, sigmoid + leaky ReLu activations, binomial loss
+-   Species identity part of the input data, internal representation then either as onehot vector (MSDM_rf, MSDM_onehot) or via embedding (MSDM_embed)
 
 ### Key findings:
 
-TBA
+- MSDM algorithms score much higher across all performance algorithms
+- Among MSDM algorithms, RF outperforms NNs significantly
 
 ## Analysis
 
-The table below shows the analysed modeling results.
-
-```{r performance, echo = FALSE, message=FALSE, warnings=FALSE}
-DT::datatable(performance) %>% 
-  formatRound(columns="value", digits=3)
-```
-
 ### Number of records
 
--   Model performance was generally better for species with more observations
--   Very poor performance below 50-100 observations
-
 ```{r, echo = FALSE, message=FALSE, warnings=FALSE}
 df_plot = performance %>% 
   dplyr::left_join(obs_count, by = "species") 
@@ -373,3 +363,68 @@ plot = plot_performance(df_plot, metric = "recall", x_var = "range_cov", x_label
 bslib::card(plot, full_screen = T)
 ```
 :::
+
+
+### Range coverage bias
+
+Range coverage bias was calculated as 1 minus the ratio of the actual range coverage and the hypothetical range coverage if all observations were maximally spread out across the range.
+
+$$
+RangeCoverageBias = 1 - \frac{RangeCoverage}{min({N_{obs\_total}} / {N_{cells\_total}}, 1)}
+$$
+
+Higher bias values indicate that occurrence records are spatially more clustered within the range of the species.
+
+```{r range_coverage_bias, echo = FALSE, message=FALSE, warnings=FALSE}
+occs_final_unproj = occs_final %>% sf::st_transform("EPSG:4326")
+
+df_occs_total = occs_final_unproj %>% 
+  st_drop_geometry() %>% 
+  group_by(species) %>% 
+  summarise(occs_total = n())
+
+df_join = df_occs_total %>% 
+  dplyr::inner_join(df_cells_total, by = "species") %>% 
+  dplyr::inner_join(df_cells_occ, by = "species") %>% 
+  dplyr::mutate(range_bias = 1-((cells_occupied / cells_total) / pmin(occs_total / cells_total, 1)))
+
+df_plot = performance %>% 
+  inner_join(df_join, by = "species")
+```
+
+::: panel-tabset
+#### AUC
+
+```{r echo = FALSE}
+plot = plot_performance(df_plot, metric = "auc", x_var = "range_bias", x_label = "Range coverage bias", x_log = F)
+bslib::card(plot, full_screen = T)
+```
+
+#### F1
+
+```{r echo = FALSE}
+plot = plot_performance(df_plot, metric = "f1", x_var = "range_bias", x_label = "Range coverage bias", x_log = F)
+bslib::card(plot, full_screen = T)
+```
+
+#### Accuracy
+
+```{r echo = FALSE}
+plot = plot_performance(df_plot, metric = "accuracy", x_var = "range_bias", x_label = "Range coverage bias", x_log = F)
+bslib::card(plot, full_screen = T)
+```
+
+#### Precision
+
+```{r echo = FALSE}
+plot = plot_performance(df_plot, metric = "precision", x_var = "range_bias", x_label = "Range coverage bias", x_log = F)
+bslib::card(plot, full_screen = T)
+```
+
+#### Recall
+
+```{r echo = FALSE}
+plot = plot_performance(df_plot, metric = "recall", x_var = "range_bias", x_label = "Range coverage bias", x_log = F)
+bslib::card(plot, full_screen = T)
+```
+:::
\ No newline at end of file
diff --git a/R/05_02_publication_analysis.R b/R/05_02_publication_analysis.R
index 19856d7..ee481d3 100644
--- a/R/05_02_publication_analysis.R
+++ b/R/05_02_publication_analysis.R
@@ -16,6 +16,7 @@ source("R/utils.R")
 load("data/r_objects/model_data.RData")
 load("data/r_objects/ssdm_results.RData")
 load("data/r_objects/msdm_embed_performance.RData")
+load("data/r_objects/msdm_onehot_performance.RData")
 load("data/r_objects/functional_groups.RData")
 
 load("data/r_objects/sa_polygon.RData")
@@ -32,10 +33,12 @@ model_data = model_data %>%
 # ------------------------------------------------------------------ #
 performance = ssdm_results %>% 
   bind_rows(msdm_embed_performance) %>% 
+  bind_rows(msdm_onehot_performance) %>% 
   dplyr::group_by(species, model, metric) %>% 
   dplyr::summarise(value = mean(value, na.rm = F)) %>% 
   dplyr::mutate(
-    metric = stringr::str_to_lower(metric),
+    metric = factor(tolower(metric), levels = c("auc", "f1", "kappa", "accuracy", "precision", "recall")),
+    model = factor(model, levels = c("GAM", "GBM", "RF", "NN", "MSDM_onehot", "MSDM_embed")),
     value = case_when(
       ((is.na(value) | is.nan(value)) & metric %in% c("auc", "f1", "accuracy", "precision", "recall")) ~ 0.5,
       ((is.na(value) | is.nan(value)) & metric %in% c("kappa")) ~ 0,
@@ -85,7 +88,7 @@ df_plot = performance %>%
   dplyr::left_join(obs_count, by = "species") 
 
 ggplot(df_plot, aes(x = obs, y = value, color = model, fill = model)) +
-  geom_point(alpha = 0.2) +
+  geom_point(alpha = 0.1) +
   geom_smooth() + 
   facet_wrap(~ metric, scales = "free_y") +
   scale_x_continuous(trans = "log10") +
@@ -148,10 +151,10 @@ library(caret)
 library(gam)
 library(gbm)
 library(cito)
-library(randomForest)
+library(ranger)
 
 # Define plotting function
-plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm")){
+plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm_embed", "msdm_onehot")){
   # Species data
   load("data/r_objects/range_maps.RData")
   
@@ -163,7 +166,7 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
   
   # Extract raster values into df
   bbox_spec = sf::st_bbox(pa_spec) %>% 
-    expand_bbox(expansion = 0.25)
+    expand_bbox(expansion = 0.75)
   
   raster_crop = terra::crop(raster_data, bbox_spec)
   new_data = raster_crop %>% 
@@ -174,22 +177,24 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
   for(algorithm in algorithms){
     message(algorithm)
     # Load model
-    tryCatch({
-      if(algorithm == "msdm"){
-        load("data/r_objects/msdm_embed_results/msdm_embed_fit_test.RData")
-        probabilities = predict(msdm_embed_fit, new_data, type = "response")[,1]
-        predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
-      } else if(algorithm == "nn") {
-        load(paste0("data/r_objects/ssdm_results/models/", spec, "_nn_fit.RData"))
-        probabilities = predict(nn_fit, new_data, type = "response")[,1]
-        predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
-      } else {
-        load(paste0("data/r_objects/ssdm_results/models/", spec, "_", algorithm, "_fit.RData"))
-        predictions = predict(get(paste0(algorithm, "_fit")), new_data, type = "raw")
-      }
-    }, error = function(e){
-      warning(toupper(algorithm), ": Model could not be loaded.")
-    })
+    if(algorithm == "msdm_onehot"){
+      load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
+      probabilities = predict(msdm_onehot_fit, new_data, type = "response")[,1]
+      predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
+    } else if(algorithm == "msdm_embed"){
+      load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
+      probabilities = predict(msdm_embed_fit, new_data, type = "response")[,1]
+      predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
+    } else if(algorithm == "nn") {
+      load(paste0("data/r_objects/ssdm_results/full_models/", spec, "_nn_fit.RData"))
+      new_data_tmp = dplyr::select(new_data, -species)
+      probabilities = predict(nn_fit, new_data_tmp, type = "response")[,1]
+      predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
+    } else {
+      load(paste0("data/r_objects/ssdm_results/full_models/", spec, "_", algorithm, "_fit.RData"))
+      new_data_tmp = dplyr::select(new_data, -species)
+      predictions = predict(get(paste0(algorithm, "_fit")), new_data_tmp, type = "raw")
+    }
     
     raster_pred = terra::rast(raster_crop, nlyrs = 1)
     values(raster_pred)[as.integer(rownames(new_data))] <- predictions
@@ -199,59 +204,76 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
          type = "classes",
          levels = c("Absent", "Present"),
          main = paste0("Range prediction (", toupper(algorithm), "): ", spec))
-    point_colors = sapply(pa_spec$present, function(x) ifelse(x == 0, "black", "white"))
-    plot(pa_spec[,"present"], col = point_colors, add = T, pch = 16)
-    plot(range_spec, border = "white", lwd = 2, col = NA, add = T)
+    point_colors = sapply(pa_spec$present, function(x) ifelse(x == 0, "#000000AA", "#FFFFFFAA"))
+    plot(pa_spec[,"present"], col = point_colors, add = T, pch = 16, cex = 0.7)
+    plot(range_spec, border = "red", lwd = 1.5, col = NA, add = T)
   }
 }
 
 # Load raster
-raster_filepaths = list.files("~/symobio-modeling/data/geospatial/raster/", full.names = T)
+raster_filepaths = list.files("~/symobio-modeling/data/geospatial/raster/", full.names = T) %>% 
+  stringr::str_sort(numeric = T) 
+
 raster_data = terra::rast(raster_filepaths) %>% 
   terra::crop(sf::st_bbox(sa_polygon)) %>% 
   terra::project(sf::st_crs(model_data)$input)
 
-spec = "Abrothrix andinus"
-pdf(file = paste0("plots/range_predictions/", spec, ".pdf"))
-plot_predictions(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "msdm"))
-dev.off()
+specs = sort(sample(levels(model_data$species), 20))
+for(spec in specs){
+  pdf(file = paste0("plots/range_predictions/", spec, ".pdf"))
+  tryCatch({
+    plot_predictions(spec, model_data, raster_data, algorithms = c("gam", "xgb", "rf", "msdm_embed", "msdm_onehot"))
+  }, finally = {
+    dev.off()
+  })
+}
 
 # ------------------------------------------------------------------ #
-# 4. Compare msdm predictions                                     ####
+# 4. Compare msdm predictions (embed vs. onehot)                  ####
 # ------------------------------------------------------------------ #
-# Check predictions for different species
-load("data/r_objects/msdm_embed_results/msdm_embed_fit_test.RData")
-specs = sample(unique(model_data$species), size = 2, replace = F)
-
-new_data = spatSample(raster_data, 10, replace = F, as.df = T) %>% 
-  drop_na()
-
-new_data1 = dplyr::mutate(new_data, species = specs[1])
-new_data2 = dplyr::mutate(new_data, species = specs[2])
-
-pred1 = predict(msdm_embed_fit, new_data1, type = "response")
-pred2 = predict(msdm_embed_fit, new_data2, type = "response")
-
-pred1
-pred2
-
-# ---> identical predictions
+load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
+load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
+
+plot_embed_vs_onehot = function(){
+  spec = sample(unique(model_data$species), 1)
+  new_data = spatSample(raster_data, 100, replace = F, as.df = T) %>% 
+    drop_na() %>% 
+    dplyr::mutate(species = spec)
+  
+  p1_em = predict(msdm_embed_fit, new_data, type = "response")
+  p1_oh = predict(msdm_onehot_fit, new_data, type = "response")
+  
+  # Compare onehot vs embed
+  plot(x = p1_em, y = p1_oh, xlab = "embed", ylab = "onehot", xlim=c(0,1), ylim = c(0,1), pch = 16, main = spec)
+  abline(a = 0, b = 1)
+  text(x = c(0.1,0.9,0.1,0.9), y = c(0.1,0.1,0.9,0.9), c("AA", "PA", "AP", "PP"), col = "red")
+  abline(h=0.5)
+  abline(v=0.5)
+}
 
-# Permute embeddings
-embeddings = coef(msdm_embed_fit)[[1]][[1]]
-msdm_embed_fit$net$e_1$parameters$weight$set_data(torch::torch_tensor(embeddings[sample.int(nrow(embeddings)), ], device = msdm_embed_fit$net$e_1$parameters$weight$device, dtype = msdm_embed_fit$net$e_1$parameters$weight$dtype ))
+plot_embed_vs_onehot()
 
-# ---> Invalid external pointer
+## --> Predictions are similar between msdms
+## --> Onehot model seems to predict slightly more presences than embedded model (???)
 
 # ------------------------------------------------------------------ #
-# 5. Compare species embeddings                                   ####
+# 5. Compare species embeddings across folds                      ####
 # ------------------------------------------------------------------ #
+obs_count = model_data %>% 
+  sf::st_drop_geometry() %>% 
+  dplyr::filter(present == 1, !is.na(fold_eval)) %>% 
+  dplyr::group_by(species) %>% 
+  dplyr::summarise(obs = n())
+
 all_embeddings = lapply(1:5, function(fold){
   load(paste0("data/r_objects/msdm_embed_results/msdm_embed_fit_fold", fold, ".RData"))
   coef(msdm_embed_fit)[[1]][[1]]
 })
 
 pairs = combn(1:5, 2, simplify = F)
+
+## Correlations
+### Embedding matrices across folds ####
 pairwise_rv = lapply(pairs, function(pair){
   FactoMineR::coeffRV(all_embeddings[[pair[1]]], all_embeddings[[pair[2]]])  
 })
@@ -266,25 +288,93 @@ for(i in seq_along(pairs)){
   p_matrix[pairs[[i]][2], pairs[[i]][1]] = pairwise_rv[[i]][["p.value"]]
 }
 
+# --> Embeddings from different folds are highly correlated
+# --> Models seem to learn consistent underlying structure in the data
+
+## Embedding vectors across folds ####
+r_embed = sapply(unique(model_data$species), function(spec){
+  sapply(pairs, function(pair){
+    cor(
+      all_embeddings[[pair[1]]][as.integer(spec),], 
+      all_embeddings[[pair[2]]][as.integer(spec),]
+    )
+  })
+}, simplify = F)
+
+r_embed_df = data.frame(
+  species = unique(model_data$species),
+  r_mean = sapply(r_embed, mean),
+  r_var = sapply(r_embed, var)
+) %>% 
+  tidyr::pivot_longer(cols = c(r_mean, r_var), names_to = "stat", values_to = "value")
+
+df_plot = obs_count %>% 
+  dplyr::left_join(r_embed_df, by = "species")
+
+ggplot(data = df_plot, aes(x = obs, y = value)) +
+  geom_point() +
+  geom_smooth() +
+  labs(title = "Mean and variance of correlation coefficient of species embeddings across folds") +
+  scale_x_continuous(transform = "log10") + 
+  facet_wrap("stat") +
+  theme_minimal()
+
+# --> Individual species embeddings are not correlated across folds
+# --> Embeddings rotated?
+
+## Pairwise distances of embedding vectors across folds ####
+all_embeddings_dist = lapply(all_embeddings, function(e){
+  as.matrix(dist(e, method = "euclidean"))
+})
+
+r_dist = sapply(unique(model_data$species), function(spec){
+  sapply(pairs, function(pair){
+    cor(
+      all_embeddings_dist[[pair[1]]][as.integer(spec),], 
+      all_embeddings_dist[[pair[2]]][as.integer(spec),]
+    )
+  })
+}, simplify = F)
+
+r_dist_df = data.frame(
+  species = unique(model_data$species),
+  r_mean = sapply(r_dist, mean),
+  r_var = sapply(r_dist, var)
+) %>% 
+  tidyr::pivot_longer(cols = c(r_mean, r_var), names_to = "stat", values_to = "value")
+
+df_plot = obs_count %>% 
+  dplyr::left_join(r_dist_df, by = "species")
+
+ggplot(data = df_plot, aes(x = obs, y = value)) +
+  geom_point() +
+  geom_smooth() +
+  labs(title = "Mean and variance of species' pairwise distances in embedding space across folds") +
+  scale_x_continuous(transform = "log10") + 
+  facet_wrap("stat") +
+  theme_minimal()
+
+## --> Pairwise distances are highly coirrelated across folds
+## --> Distance vectors for abundant species tend to be more similar across folds
+## --> Some potential for informed embeddings with respect to rare species?
+
 # ------------------------------------------------------------------ #
-# 6. Analyse species embeddings                                      ####
+# 6. Analyse embedding of full model                              ####
 # ------------------------------------------------------------------ #
-embeddings = coef(msdm_embed_fit)[[1]][[1]]
-rownames(embeddings) = levels(species_lookup$species)
-
+load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
 species_lookup = data.frame(species = levels(model_data$species)) %>% 
   dplyr::mutate(
     genus = stringr::str_extract(species, "([a-zA-Z]+) (.*)", group = 1)
   )
 
+embeddings = coef(msdm_embed_fit)[[1]][[1]]
+rownames(embeddings) = levels(species_lookup$species)
+
 ##  Dimensionality reduction        ####
 ### PCA                             ####
 pca_result = prcomp(t(embeddings), scale. = TRUE)
 var_explained = pca_result$sdev^2 / sum(pca_result$sdev^2)
-plot(var_explained)
-
-# --> Variance explained is distributed rather evenly across dimensions
-# --> Dimensionality reduction probably not useful
+plot(var_explained) # --> First two dimensions explain ~40% of variance
 
 coords = pca_result$rotation[,1:2] %>%
   magrittr::set_colnames(c("X", "Y"))
@@ -296,11 +386,10 @@ df_plot = species_lookup %>%
 ggplot(df_plot, aes(x = X, y = Y, col = functional_group, label=genus)) +
   geom_point() +
   geom_text(hjust=0, vjust=0) +
-  guides(col="none")
-
+  theme_minimal()
 
-### TSNE                            ####
-tsne_result <- Rtsne(embeddings, verbose = FALSE)
+### T-SNE                            ####
+tsne_result <- Rtsne(embeddings, verbose = TRUE)
 coords = tsne_result$Y %>%
   magrittr::set_colnames(c("X", "Y"))
 
@@ -311,13 +400,30 @@ df_plot = species_lookup %>%
 ggplot(df_plot, aes(x = X, y = Y, col = functional_group, label=genus)) +
   geom_point() +
   geom_text(hjust=0, vjust=0) +
-  guides(col="none")
+  labs(
+    title = "T-SNE Ordination of species embeddings",
+    color = "Functional group",
+  ) +
+  theme_minimal(base_size = 14) +
+  theme(
+    strip.text = element_text(face = "bold"),
+    axis.text.x = element_text(angle = 45, hjust = 1),
+    legend.position = "right"
+  )
+
+ggsave("plots/publication/tsne_ordination.pdf", 
+       device = "pdf", 
+       scale = 2,
+       width = 20, 
+       height = 18,
+       units = "cm")
 
 ### KNN                             ####
+rownames(embeddings) = species_lookup$species
 embeddings_dist = as.matrix(dist(embeddings))
 
 k = 9
-knn_results = sapply(as.integer(species_lookup$species), FUN = function(spec){
+knn_results = sapply(species_lookup$species, FUN = function(spec){
   return(sort(embeddings_dist[spec,])[2:(k+1)])
 }, USE.NAMES = T, simplify = F)
 
@@ -346,12 +452,14 @@ plot_knn_ranges = function(spec){
 
 plot_knn_ranges(sample(species_lookup$species, 1)) # Repeat this line to plot random species
 
-# Clustering
+## Clustering             ####
+rownames(embeddings) = species_lookup$species
 embeddings_dist = dist(embeddings, method = "euclidean")
 hclust_complete = hclust(embeddings_dist, method = "complete")
-plot(hclust_complete, main = "Complete Linkage", xlab = "", sub = "", labels = FALSE)
+plot(hclust_complete, main = "Complete Linkage", xlab = "", sub = "", cex = 0.5)
 
-# Phylo Correlation
+## Correlations           ####
+### Phylo dist            ####
 load("data/r_objects/phylo_dist.RData")
 
 species_intersect = intersect(colnames(phylo_dist), species_lookup$species)
@@ -362,7 +470,7 @@ e_indices = match(species_intersect, species_lookup$species)
 embeddings_dist_subset = as.matrix(embeddings_dist)[e_indices, e_indices]
 FactoMineR::coeffRV(embeddings_dist_subset, phylo_dist_subset)
 
-# Trait Correlation
+### Functional dist       ####
 load("data/r_objects/func_dist.RData")
 
 species_intersect = intersect(colnames(func_dist), species_lookup$species)
@@ -373,7 +481,7 @@ e_indices = match(species_intersect, species_lookup$species)
 embeddings_dist_subset = as.matrix(embeddings_dist)[e_indices, e_indices]
 FactoMineR::coeffRV(embeddings_dist_subset, func_dist_subset)
 
-# Range Correlation
+### Range dist            ####
 load("data/r_objects/range_dist.RData")
 
 species_intersect = intersect(colnames(range_dist), species_lookup$species)
@@ -383,4 +491,3 @@ range_dist_subset = range_dist[species_intersect, species_intersect]
 e_indices = match(species_intersect, species_lookup$species)
 embeddings_dist_subset = as.matrix(embeddings_dist)[e_indices, e_indices]
 FactoMineR::coeffRV(embeddings_dist_subset, range_dist_subset)
-
diff --git a/R/_publish.yml b/R/_publish.yml
index 9d5a5d0..91c7c69 100644
--- a/R/_publish.yml
+++ b/R/_publish.yml
@@ -10,3 +10,7 @@
   quarto-pub:
     - id: cfaa8a4d-31e3-410c-ba1a-399aad5582fd
       url: 'https://chrkoenig.quarto.pub/sdm-performance-analysis-8272'
+- source: 05_01_performance_report.qmd
+  quarto-pub:
+    - id: 25ccfe7b-b9d4-4bc9-bbf8-823676aab7bd
+      url: 'https://chrkoenig.quarto.pub/sdm-performance-report-f08b'
diff --git a/snippets.R b/snippets.R
deleted file mode 100644
index 40f5523..0000000
--- a/snippets.R
+++ /dev/null
@@ -1,160 +0,0 @@
-### Range coverage bias
-
-Range coverage bias was calculated as 1 minus the ratio of the actual range coverage and the hypothetical range coverage if all observations were maximally spread out across the range.
-
-$$
-  RangeCoverageBias = 1 - \frac{RangeCoverage}{min({N_{obs\_total}} / {N_{cells\_total}}, 1)}
-  $$
-    
-    Higher bias values indicate that occurrence records are spatially more clustered within the range of the species.
-  
-  -   There was no strong relationship between range coverage bias and model performance
-  
-  ```{r range_coverage_bias, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
-  df_occs_total = occs_final %>% 
-    st_drop_geometry() %>% 
-    group_by(species) %>% 
-    summarise(occs_total = n())
-  
-  df_join = df_occs_total %>% 
-    dplyr::inner_join(df_cells_total, by = "species") %>% 
-    dplyr::inner_join(df_cells_occ, by = "species") %>% 
-    dplyr::mutate(range_bias = 1-((cells_occupied / cells_total) / pmin(occs_total / cells_total, 1)))
-  
-  df_plot = performance %>% 
-    inner_join(df_join, by = "species")
-  
-  # Calculate regression lines for each model and metric combination
-  suppressWarnings({
-    regression_lines = df_plot %>%
-      group_by(model, metric) %>%
-      group_modify(~lin_regression(.x$range_bias, .x$value))
-  })
-  
-  # Create base plot
-  plot <- plot_ly() %>% 
-    layout(
-      title = "Model Performance vs. Range coverage bias",
-      xaxis = list(title = "Range coverage bias"),
-      yaxis = list(title = "Value"),
-      legend = list(x = 1.1, y = 0.5),  # Move legend to the right of the plot
-      margin = list(r = 150),  # Add right margin to accommodate legend
-      hovermode = 'closest',
-      updatemenus = list(
-        list(
-          type = "dropdown",
-          active = 0,
-          buttons = plotly_buttons
-        )
-      )
-    )
-  
-  # Add regression lines and confidence intervals for each model
-  for (model_name in unique(df_plot$model)) {
-    plot <- plot %>%
-      add_markers(
-        data = filter(df_plot, model == model_name),
-        x = ~ range_bias,
-        y = ~ value,
-        color = model_name,  # Set color to match legendgroup
-        legendgroup = model_name,
-        opacity = 0.6,
-        name = ~model,
-        hoverinfo = 'text',
-        text = ~paste("Species:", species, "<br>Range coverage bias:", range_bias, "<br>Value:", round(value, 3)),
-        transforms = list(
-          list(
-            type = 'filter',
-            target = ~metric,
-            operation = '=',
-            value = focal_metrics[1]
-          )
-        )
-      )
-  }
-  
-  for (model_name in unique(df_plot$model)) {
-    reg_data = dplyr::filter(regression_lines, model == model_name)
-    plot = plot %>% 
-      add_lines(
-        data = reg_data,
-        x = ~x,
-        y = ~fit,
-        color = model_name,  # Set color to match legendgroup
-        legendgroup = model_name,
-        name = paste(model_name, '(fit)'),
-        showlegend = FALSE,
-        transforms = list(
-          list(
-            type = 'filter',
-            target = ~metric,
-            operation = '=',
-            value = focal_metrics[1]
-          )
-        )
-      )
-  }
-  
-  
-  bslib::card(plot, full_screen = T)
-  ```
-  
-  ### Functional group
-  
-  Functional groups were assigned based on taxonomic order. The following groupings were used:
-    
-    | Functional group      | Taxomic orders                                                        |
-    |-------------------|-----------------------------------------------------|
-    | large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla                    |
-    | small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha |
-    | arboreal              | Primates, Pilosa                                                      |
-    | flying                | Chiroptera                                                            |
-    
-    -   Models for bats tended to perform slightly worse than for other groups.
-  
-  ```{r functional_groups, echo = FALSE, message=FALSE, warnings=FALSE, eval=F}
-  df_plot = performance %>% 
-    dplyr::left_join(functional_groups, by = c("species" = "name_matched"))
-  
-  plot <- plot_ly(
-    data = df_plot,
-    x = ~functional_group,
-    y = ~value,
-    color = ~model,
-    type = 'box',
-    boxpoints = "all",
-    jitter = 1,
-    pointpos = 0,
-    hoverinfo = 'text',
-    text = ~paste("Species:", species, "<br>Functional group:", functional_group, "<br>Value:", round(value, 3)),
-    transforms = list(
-      list(
-        type = 'filter',
-        target = ~metric,
-        operation = '=',
-        value = focal_metrics[1]  # default value
-      )
-    )
-  )
-  
-  plot <- plot %>%
-    layout(
-      title = "Model Performance vs. Functional Group",
-      xaxis = list(title = "Functional group"),
-      yaxis = list(title = "Value"),
-      legend = list(x = 1.1, y = 0.5),  # Move legend to the right of the plot
-      margin = list(r = 150),  # Add right margin to accommodate legend
-      hovermode = 'closest',
-      boxmode = "group",
-      updatemenus = list(
-        list(
-          type = "dropdown",
-          active = 0,
-          buttons = plotly_buttons
-        )
-      )
-    )
-  
-  bslib::card(plot, full_screen = T)
-  ```
-  
\ No newline at end of file
-- 
GitLab