From 11bd13c837559b755914e114860548ee926629c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Wed, 6 Nov 2024 10:27:53 +0100
Subject: [PATCH] Implement cito DNNs for SSDM, restructure files

---
 ...ange_maps.R => 01_range_map_preparation.R} |   0
 ...ps.R => 02_functional_group_preparation.R} |   0
 R/03_presence_absence_preparation.R           | 208 ++++++++++++++++++
 R/04_01_ssdm_modeling.R                       | 156 +++++++++++++
 ...t_species.R => 04_distribution_modeling.R} | 161 ++++----------
 ...alysis.qmd => 05_performance_analysis.qmd} | 110 ++++++++-
 R/install_torch.R                             |  88 --------
 R/utils.R                                     |  41 +++-
 8 files changed, 530 insertions(+), 234 deletions(-)
 rename R/{process_range_maps.R => 01_range_map_preparation.R} (100%)
 rename R/{assign_functional_groups.R => 02_functional_group_preparation.R} (100%)
 create mode 100644 R/03_presence_absence_preparation.R
 create mode 100644 R/04_01_ssdm_modeling.R
 rename R/{model_target_species.R => 04_distribution_modeling.R} (51%)
 rename R/{performance_analysis.qmd => 05_performance_analysis.qmd} (80%)
 delete mode 100644 R/install_torch.R

diff --git a/R/process_range_maps.R b/R/01_range_map_preparation.R
similarity index 100%
rename from R/process_range_maps.R
rename to R/01_range_map_preparation.R
diff --git a/R/assign_functional_groups.R b/R/02_functional_group_preparation.R
similarity index 100%
rename from R/assign_functional_groups.R
rename to R/02_functional_group_preparation.R
diff --git a/R/03_presence_absence_preparation.R b/R/03_presence_absence_preparation.R
new file mode 100644
index 0000000..fe78cea
--- /dev/null
+++ b/R/03_presence_absence_preparation.R
@@ -0,0 +1,208 @@
+# General packages
+library(dplyr)
+library(tidyr)
+library(furrr)
+
+# Geo packages
+library(terra)
+library(CoordinateCleaner)
+library(sf)
+
+# DB packages
+library(Symobio)
+library(DBI)
+
+source("R/utils.R")
+
+con = db_connect() # Connection to Symobio
+sf::sf_use_s2(use_s2 = FALSE) # Don't use spherical geometry
+
+# ---------------------------------------------------------------------------#
+# Prepare Geodata                                                         ####
+# ---------------------------------------------------------------------------#
+raster_info = tbl(con, "datasets") %>% 
+  dplyr::filter(stringr::str_detect(name, "CHELSA")) %>% 
+  collect()
+
+# raster_filepaths = list.files(raster_info$raw_data_path, pattern = ".tif$", full.names = T) %>% 
+#   stringr::str_sort(numeric = T)
+
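+# NOTE: uses a hardcoded local copy of the CHELSA bioclim rasters; the
+# DB-derived path above is kept for reference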
+raster_filepaths = list.files("I:/mas_data/00_data/processed/CHELSA_v2-1_bioclim", pattern = ".tif$", full.names = T) %>% 
+  stringr::str_sort(numeric = T)
+
+sa_polygon = rnaturalearth::ne_countries() %>% 
+  dplyr::filter(continent == "South America") %>% 
+  sf::st_union()
+
+# ---------------------------------------------------------------------------#
+# Prepare Occurrence Data                                                 ####
+# ---------------------------------------------------------------------------#
+load("data/r_objects/range_maps.RData")
+target_species = unique(range_maps$name_matched[!is.na(range_maps$name_matched)])
+
+occs = tbl(con, "species_occurrences") %>% 
+  dplyr::filter(species %in% target_species) %>% 
+  dplyr::select(-year) %>% 
+  dplyr::distinct() %>% 
+  collect() %>% 
+  sf::st_as_sf(coords = c("longitude", "latitude"), remove = F, crs = sf::st_crs(4326)) %>% 
+  sf::st_filter(sa_polygon)
+
+occs_flagged = occs %>% 
+  dplyr::distinct(species, coordinate_id, longitude, latitude) %>% 
+  group_by(species) %>% 
+  group_split() %>% 
+  purrr::map(                     # Loop over species individually due to bug in CoordinateCleaner
+    CoordinateCleaner::clean_coordinates,
+    lon = "longitude", 
+    lat = "latitude",
+    tests = c("centroids", "gbif", "equal", "institutions", "outliers"),
+    outliers_method = "quantile",
+    verbose = F
+  ) %>% 
+  bind_rows() %>% 
+  dplyr::filter(.summary == T) %>% 
+  dplyr::select(species, coordinate_id, longitude, latitude)
+
+env_vars = tbl(con, "raster_extracts_num") %>% 
+  dplyr::filter(
+    coordinate_id %in% occs$coordinate_id,
+    metric == "mean"
+  ) %>% 
+  arrange(raster_layer_id) %>% 
+  tidyr::pivot_wider(id_cols = coordinate_id, names_from = raster_layer_id, names_prefix = "layer_", values_from = value) %>% 
+  collect()
+
+occs_final = occs %>% 
+  inner_join(occs_flagged, by = c("species", "coordinate_id", "longitude", "latitude")) %>% 
+  inner_join(env_vars, by = "coordinate_id") %>% 
+  dplyr::select(-coordinate_id)
+
+save(occs_final, file = "data/r_objects/occs_final.RData")
+
+# ---------------------------------------------------------------------------#
+# Create SSDM dataset                                                     ####
+# ---------------------------------------------------------------------------#
+load("data/r_objects/occs_final.RData")
+load("data/r_objects/sa_polygon.RData")
+sf::sf_use_s2(use_s2 = FALSE)
+
+occs_split = split(occs_final, occs_final$species)
+
+future::plan("multisession", workers = 8)
+ssdm_data = furrr::future_map(occs_split, .progress = TRUE, .options = furrr::furrr_options(seed = 42), .f = function(occs_spec){
+  # Skip species with too few occurrence records
+  if(nrow(occs_spec) < 5){
+    return(NULL)
+  }
+  
+  # Define model/sampling region
+  occs_bbox = occs_spec %>% 
+    sf::st_bbox() %>% 
+    expand_bbox(min_span = 1, expansion = 0.25) %>% 
+    sf::st_as_sfc() %>% 
+    st_set_crs(st_crs(occs_spec)) 
+  
+  sample_region = suppressMessages(
+    st_as_sf(st_intersection(occs_bbox, sa_polygon))
+  )
+  
+  # Sample pseudo-absences (one per presence record)
+  sample_points = st_as_sf(st_sample(sample_region, nrow(occs_spec))) %>% 
+    dplyr::mutate(
+      species = occs_spec$species[1],
+      longitude = sf::st_coordinates(.)[,1],
+      latitude = sf::st_coordinates(.)[,2]
+    ) %>% 
+    dplyr::rename(geometry = "x")
+  
+  abs_spec = terra::rast(raster_filepaths) %>% 
+    setNames(paste0("layer_", 1:19)) %>% 
+    terra::extract(sample_points) %>% 
+    dplyr::select(-ID) %>% 
+    dplyr::mutate(present = 0) %>% 
+    tibble::as_tibble() %>% 
+    bind_cols(sample_points)   # adds species, longitude, latitude, geometry
+  
+  # Create presence-absence dataframe
+  pa_spec = occs_spec %>% 
+    dplyr::mutate(present = 1) %>% 
+    bind_rows(abs_spec) 
+  
+  # Define cross-validation folds
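+  # (spatial blocking keeps the folds geographically separated, which reduces
+  # spatial autocorrelation between training and validation data)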
+  spatial_folds = blockCV::cv_spatial(
+    pa_spec,
+    column = "present",
+    k = 5,
+    progress = F, plot = F, report = F
+  )
+  
+  pa_spec$fold = spatial_folds$folds_ids
+  pa_spec$geometry = NULL
+  
+  # Split into train (70%) and test (30%) datasets
+  train_index = caret::createDataPartition(pa_spec$present, p = 0.7, list = FALSE)
+  pa_spec$train = 0
+  pa_spec$train[train_index] = 1
+  
+  return(pa_spec)
+})
+
+ssdm_data = bind_rows(ssdm_data)
+save(ssdm_data, file = "data/r_objects/ssdm_data.RData")
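+
+# Optional sanity check (sketch, assuming the steps above ran through):
+# presences and pseudo-absences should be roughly balanced per species
+# ssdm_data %>%
+#   dplyr::count(species, present) %>%
+#   tidyr::pivot_wider(names_from = present, values_from = n)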
+
+# ---------------------------------------------------------------------------#
+# Create MSDM dataset                                                     ####
+# ---------------------------------------------------------------------------#
+load("data/r_objects/occs_final.RData")
+load("data/r_objects/sa_polygon.RData")
+sf::sf_use_s2(use_s2 = FALSE)
+
+occs_split = split(occs_final, occs_final$species)
+
+future::plan("multisession", workers = 8)
+msdm_data = furrr::future_map(occs_split, .progress = TRUE, .options = furrr::furrr_options(seed = 42), .f = function(occs_spec){
+  # Define model/sampling region
+  occs_bbox = occs_spec %>% 
+    sf::st_bbox() %>% 
+    expand_bbox(min_span = 1, expansion = 0.25) %>% 
+    sf::st_as_sfc() %>% 
+    st_set_crs(st_crs(occs_spec)) 
+  
+  sample_region = suppressMessages(
+    st_as_sf(st_intersection(occs_bbox, sa_polygon))
+  )
+  
+  # Sample pseudo-absences (one per presence record)
+  sample_points = st_as_sf(st_sample(sample_region, nrow(occs_spec))) %>% 
+    dplyr::mutate(
+      species = occs_spec$species[1],
+      longitude = sf::st_coordinates(.)[,1],
+      latitude = sf::st_coordinates(.)[,2]
+    ) %>% 
+    dplyr::rename(geometry = "x")
+  
+  abs_spec = terra::rast(raster_filepaths) %>% 
+    setNames(paste0("layer_", 1:19)) %>% 
+    terra::extract(sample_points) %>% 
+    dplyr::select(-ID) %>% 
+    dplyr::mutate(present = 0) %>% 
+    tibble::as_tibble() %>% 
+    bind_cols(sample_points)   # adds species, longitude, latitude, geometry
+  
+  # Create presence-absence dataframe
+  pa_spec = occs_spec %>% 
+    dplyr::mutate(present = 1) %>% 
+    bind_rows(abs_spec) 
+  
+  return(pa_spec)
+})
+
+msdm_data = bind_rows(msdm_data)
+save(msdm_data, file = "data/r_objects/msdm_data.RData")
diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
new file mode 100644
index 0000000..5e8b92b
--- /dev/null
+++ b/R/04_01_ssdm_modeling.R
@@ -0,0 +1,156 @@
+library(dplyr)
+library(tidyr)
+library(furrr)
+library(caret)
+library(cito)
+library(pROC)
+
+source("R/utils.R")
+
+load("data/r_objects/ssdm_data.RData")
+
+pa_split = split(ssdm_data, ssdm_data$species)
+
+# ----------------------------------------------------------------------#
+# Train models                                                       ####
+# ----------------------------------------------------------------------#
+future::plan("multisession", workers = 8)
+ssdm_results = furrr::future_map(pa_split, .options = furrr::furrr_options(seed = 123), .f = function(pa_spec){
+  # Initial check
+  if(nrow(pa_spec) < 10){
+    return(NULL)
+  }
+  
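+  # caret with classProbs = TRUE requires factor levels that are valid R names,
+  # hence "A"/"P" rather than "0"/"1"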
+  pa_spec$present_fct = factor(pa_spec$present, levels = c("0", "1"), labels = c("A", "P"))
+  train_data = dplyr::filter(pa_spec, train == 1)
+  test_data = dplyr::filter(pa_spec, train == 0)
+  
+  # Define empty result for performance eval
+  na_performance = list(    
+    AUC = NA,
+    Accuracy = NA,
+    Kappa = NA,
+    Precision = NA,
+    Recall = NA,
+    F1 = NA
+  )
+  
+  # Define predictors
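+  # layer_1 ... layer_19 are the 19 CHELSA bioclim variables extracted in script 03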
+  predictors = paste0("layer_", 1:19)
+  
+  # Define caret training routine #####
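+  # One training index per spatial fold (all rows outside that fold), so caret
+  # performs leave-one-fold-out spatial cross-validation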
+  index_train = lapply(unique(sort(train_data$fold)), function(x){
+    return(which(train_data$fold != x))
+  })
+  
+  train_ctrl = trainControl(
+    search = "grid",
+    classProbs = TRUE, 
+    index = index_train,
+    summaryFunction = twoClassSummary, 
+    savePredictions = "final"
+  )
+  
+  # Random Forest #####
+  rf_performance = tryCatch({
+    rf_grid = expand.grid(
+      mtry = c(3,7,11,15,19)                # Number of randomly selected predictors
+    )
+    
+    rf_fit = caret::train(
+      x = train_data[, predictors],
+      y = train_data$present_fct,
+      method = "rf",
+      metric = "ROC",
+      tuneGrid = rf_grid,
+      trControl = train_ctrl
+    )
+    evaluate_model(rf_fit, test_data)
+  }, error = function(e){
+    na_performance
+  })
+  
+  # Gradient Boosting Machine ####
+  gbm_performance = tryCatch({
+    gbm_grid <- expand.grid(
+      n.trees = c(100, 500, 1000, 1500),       # Number of boosting iterations
+      interaction.depth = c(3, 5, 7),          # Maximum depth of each tree
+      shrinkage = c(0.01, 0.005, 0.001),       # Learning rate
+      n.minobsinnode = c(10, 20)               # Minimum number of observations in terminal nodes
+    )
+    
+    gbm_fit = train(
+      x = train_data[, predictors],
+      y = train_data$present_fct,
+      method = "gbm",
+      metric = "ROC",
+      verbose = F,
+      tuneGrid = gbm_grid,
+      trControl = train_ctrl
+    )
+    evaluate_model(gbm_fit, test_data)
+  }, error = function(e){
+    na_performance
+  })
+  
+  # Generalized Linear Model ####
+  glm_performance = tryCatch({
+    glm_fit = train(
+      x = train_data[, predictors],
+      y = train_data$present_fct,
+      method = "glm",
+      family = binomial, 
+      metric = "ROC",
+      preProcess = c("center", "scale"),
+      trControl = train_ctrl
+    )
+    evaluate_model(glm_fit, test_data)
+  }, error = function(e){
+    na_performance
+  })
+  
+  # Neural Network ####
+  nn_performance = tryCatch({
+    nn_fit = dnn(
+      X = train_data[, predictors],
+      Y = train_data$present,
+      hidden = c(500L, 200L, 50L),
+      loss = "binomial",
+      activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+      epochs = 500L, 
+      lr = 0.02,   
+      baseloss = 10,
+      batchsize = ceiling(nrow(train_data) / 4), # ensure an integer batch size
+      dropout = 0.1,  # Regularization 
+      optimizer = config_optimizer("adam", weight_decay = 0.001),
+      lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+      early_stopping = 200, # stop training when validation loss does not decrease anymore
+      validation = 0.3, # used for early stopping and lr_scheduler 
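+      # NOTE: assumes a CUDA-capable GPU; set device = "cpu" to train without one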
+      device = "cuda",
+      bootstrap = 5
+    )
+    
+    evaluate_model(nn_fit, test_data)
+  }, error = function(e){
+    na_performance
+  })
+  
+  # Summarize results
+  performance_summary = tibble(
+    species = pa_spec$species[1],
+    obs = nrow(pa_spec),
+    model = c("RF", "GBM", "GLM", "NN"),
+    auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC, nn_performance$AUC),
+    accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy, nn_performance$Accuracy),
+    kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa, nn_performance$Kappa),
+    precision = c(rf_performance$Precision, gbm_performance$Precision, glm_performance$Precision, nn_performance$Precision),
+    recall = c(rf_performance$Recall, gbm_performance$Recall, glm_performance$Recall, nn_performance$Recall),
+    f1 = c(rf_performance$F1, gbm_performance$F1, glm_performance$F1, nn_performance$F1)
+  )
+  
+  return(performance_summary)
+})
+
+ssdm_results = bind_rows(ssdm_results)
+
+save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
\ No newline at end of file
diff --git a/R/model_target_species.R b/R/04_distribution_modeling.R
similarity index 51%
rename from R/model_target_species.R
rename to R/04_distribution_modeling.R
index 420d219..9baa270 100644
--- a/R/model_target_species.R
+++ b/R/04_distribution_modeling.R
@@ -1,87 +1,3 @@
-# General packages
-library(dplyr)
-library(tidyr)
-library(ggplot2)
-library(furrr)
-
-# Geo packages
-library(terra)
-library(CoordinateCleaner)
-library(sf)
-
-# DB packages
-library(Symobio)
-library(DBI)
-
-# Modeling packages
-library(caret)
-library(pROC)
-library(cito)
-
-source("R/utils.R")
-
-con = db_connect()
-sf::sf_use_s2(use_s2 = FALSE)
-
-# ---------------------------------------------------------------------------#
-# Prepare Geodata                                                         ####
-# ---------------------------------------------------------------------------#
-raster_info = tbl(con, "datasets") %>% 
-  dplyr::filter(stringr::str_detect(name, "CHELSA")) %>% 
-  collect()
-
-raster_filepaths = list.files(raster_info$raw_data_path, pattern = ".tif$", full.names = T) %>% 
-  stringr::str_sort(numeric = T)
-
-sa_polygon = rnaturalearth::ne_countries() %>% 
-  dplyr::filter(continent == "South America") %>% 
-  sf::st_union()
-
-# ---------------------------------------------------------------------------#
-# Prepare Occurrence Data                                                 ####
-# ---------------------------------------------------------------------------#
-load("data/r_objects/range_maps.RData")
-target_species = unique(range_maps$name_matched[!is.na(range_maps$name_matched)])
-
-occs = tbl(con, "species_occurrences") %>% 
-  dplyr::filter(species %in% target_species) %>% 
-  dplyr::select(-year) %>% 
-  dplyr::distinct() %>% 
-  collect() %>% 
-  sf::st_as_sf(coords = c("longitude", "latitude"), remove = F, crs = sf::st_crs(4326)) %>% 
-  sf::st_filter(sa_polygon)
-
-occs_flagged = occs %>% 
-  dplyr::distinct(species, coordinate_id, longitude, latitude) %>% 
-  group_by(species) %>% 
-  group_split() %>% 
-  purrr::map(                     # Loop over species individually due to bug in CoordinateCleaner
-    CoordinateCleaner::clean_coordinates,
-    lon = "longitude", 
-    lat = "latitude",
-    tests = c("centroids", "gbif", "institutions", "outliers"),
-    outliers_method = "quantile",
-    verbose = F
-  ) %>% 
-  bind_rows() %>% 
-  dplyr::filter(.summary == T) %>% 
-  dplyr::select(species, coordinate_id, longitude, latitude)
-
-env_vars = tbl(con, "raster_extracts_num") %>% 
-  dplyr::filter(
-    coordinate_id %in% occs$coordinate_id,
-    metric == "mean"
-  ) %>% 
-  arrange(raster_layer_id) %>% 
-  tidyr::pivot_wider(id_cols = coordinate_id, names_from = raster_layer_id, names_prefix = "layer_", values_from = value) %>% 
-  collect()
-
-occs_final = occs %>% 
-  inner_join(occs_flagged, by = c("species", "coordinate_id", "longitude", "latitude")) %>% 
-  inner_join(env_vars, by = "coordinate_id") %>% 
-  dplyr::select(-coordinate_id)
-
-save(occs_final, file = "data/r_objects/occs_final.RData")
 # ---------------------------------------------------------------------------#
 # Main loop                                                               ####
 # ---------------------------------------------------------------------------#
@@ -90,9 +6,9 @@ load("data/r_objects/sa_polygon.RData")
 
 occs_split = split(occs_final, occs_final$species)
 
-future::plan("multisession", workers = 16)
+future::plan("multisession", workers = 20)
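+# NOTE: only the first 20 species are modeled below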
 
-model_results = furrr::future_map(occs_split, .options = furrr::furrr_options(seed = 123), .f = function(occs_spec){
+ssdm_data = furrr::future_map(occs_split[1:20], .options = furrr::furrr_options(seed = 123), .f = function(occs_spec){
   # Initial check
   if(nrow(occs_spec) < 10){
     return(NULL)
@@ -181,35 +97,32 @@ model_results = furrr::future_map(occs_split, .options = furrr::furrr_options(se
   # Train models                 ####
   # ------------------------------- #
   ## cito ####
-  # model_data_nn = model_data %>% 
-  #   dplyr::mutate(across(all_of(predictors), scale)) %>% 
-  #   select(-species, -longitude, -latitude)
-  # 
-  # train_data_nn = model_data_nn[train_index, ]
-  # test_data_nn  = model_data_nn[-train_index, ]
-  # 
-  # nn_fit = dnn(
-  #   Y = train_data_nn$presence,
-  #   X = as.matrix(train_data_nn[, predictors]),
-  #   hidden = c(500L, 500L),
-  #   loss = "binomial",
-  #   activation = "leaky_relu",
-  #   epochs = 2000L, 
-  #   lr = 0.02,    
-  #   dropout = 0.2,  # Regularization 
-  #   burnin = Inf,
-  #   optimizer = config_optimizer("adam", weight_decay = 0.001),
-  #   lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 10, factor = 0.7),
-  #   early_stopping = 500L, # stop training when validation loss does not decrease anymore
-  #   validation = 0.2, # used for early stopping and lr_scheduler 
-  #   device = "cpu", 
-  #   bootstrap = 50L
-  # )
-  # 
-  # nn_fit$successfull
-  # preds = predict(nn_fit, type = "response") # --> Strange discrete steps of size 1/bootstrap_value
-  # plot(preds)
-  # Metrics::auc(train_data_nn$presence, round(preds[,1]))
+  nn_performance = tryCatch({
+    nn_fit = dnn(
+      X = train_data[, predictors],
+      Y = train_data$presence,
+      hidden = c(500L, 200L, 50L),
+      loss = "binomial",
+      activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+      epochs = 500L, 
+      lr = 0.02,   
+      baseloss = 10,
+      batchsize = ceiling(nrow(train_data) / 4), # ensure an integer batch size
+      dropout = 0.1,  # Regularization 
+      optimizer = config_optimizer("adam", weight_decay = 0.001),
+      lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+      early_stopping = 200, # stop training when validation loss does not decrease anymore
+      validation = 0.3, # used for early stopping and lr_scheduler 
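+      # NOTE: assumes a CUDA-capable GPU; set device = "cpu" to train without one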
+      device = "cuda",
+      bootstrap = 5
+    )
+    
+    evaluate_model(nn_fit, test_data)
+  }, error = function(e){
+    na_performance
+  })
   
   ## caret ####  
   # Define training routine
@@ -287,16 +200,18 @@ model_results = furrr::future_map(occs_split, .options = furrr::furrr_options(se
   performance_summary = tibble(
     species = species,
     obs = nrow(occs_spec),
-    model = c("RF", "GBM", "GLM"),
-    auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC),
-    accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy),
-    kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa),
-    precision = c(rf_performance$Precision, gbm_performance$Precision, glm_performance$Precision),
-    recall = c(rf_performance$Recall, gbm_performance$Recall, glm_performance$Recall),
-    f1 = c(rf_performance$F1, gbm_performance$F1, glm_performance$F1)
+    model = c("NN", "RF", "GBM", "GLM"),
+    auc = c(nn_performance$AUC, rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC),
+    accuracy = c(nn_performance$Accuracy, rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy),
+    kappa = c(nn_performance$Kappa, rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa),
+    precision = c(nn_performance$Precision, rf_performance$Precision, gbm_performance$Precision, glm_performance$Precision),
+    recall = c(nn_performance$Recall, rf_performance$Recall, gbm_performance$Recall, glm_performance$Recall),
+    f1 = c(nn_performance$F1, rf_performance$F1, gbm_performance$F1, glm_performance$F1)
   )
   
   return(performance_summary)
 })
 
-save(model_results, file = "data/r_objects/model_results.RData")
+model_results = bind_rows(ssdm_data)
+
+save(model_results, file = "data/r_objects/model_results.RData")
\ No newline at end of file
diff --git a/R/performance_analysis.qmd b/R/05_performance_analysis.qmd
similarity index 80%
rename from R/performance_analysis.qmd
rename to R/05_performance_analysis.qmd
index a0d098e..8941c87 100644
--- a/R/performance_analysis.qmd
+++ b/R/05_performance_analysis.qmd
@@ -11,7 +11,7 @@ library(sf)
 library(plotly)
 library(DT)
 
-load("../data/r_objects/model_results.RData")
+load("../data/r_objects/ssdm_results.RData")
 load("../data/r_objects/range_maps.RData")
 load("../data/r_objects/range_maps_gridded.RData")
 load("../data/r_objects/occs_final.RData")
@@ -21,7 +21,7 @@ sf::sf_use_s2(use_s2 = FALSE)
 
 ## Summary
 
-This document summarizes the performance of three SDM algorithms (Random Forest, Gradient Boosting Machine, Generalized Linear Model) for `{r} length(model_results)` South American mammal species. We use six metrics (AUC, F1, kappa, accuracy, precision, and recall) to evaluate model performance and look at how performance varies with five factors (number of records, range size, range coverage, range coverage bias, and functional group).
+This document summarizes the performance of four SDM algorithms (Random Forest, Gradient Boosting Machine, Generalized Linear Model, Deep Neural Network) for `{r} length(unique(ssdm_results$species))` South American mammal species. We use six metrics (AUC, F1, kappa, accuracy, precision, and recall) to evaluate model performance and look at how performance varies with five factors (number of records, range size, range coverage, range coverage bias, and functional group).
 
 Modeling decisions:
 
@@ -34,7 +34,8 @@ Modeling decisions:
 
 Key findings:
 
--   RF performed best, GBM slightly worse, GLM worst
+-   Performance: RF > GBM > GLM >> NN
+-   Convergence problems with Neural Network models
 -   More occurrence records and larger range sizes tended to improve model performance
 -   Higher range coverage correlated with better performance.
 -   Range coverage bias and functional group showed some impact but were less consistent
@@ -44,9 +45,7 @@ Key findings:
 The table below shows the analysed modeling results.
 
 ```{r performance, echo = FALSE, message=FALSE, warnings=FALSE}
-performance = model_results %>% 
-  purrr::keep(inherits, 'data.frame') %>% 
-  bind_rows() %>% 
+performance = ssdm_results %>% 
   pivot_longer(c(auc, accuracy, kappa, precision, recall, f1), names_to = "metric") %>% 
   dplyr::filter(!is.na(value)) %>% 
   dplyr::mutate(
@@ -124,6 +123,7 @@ Range size was calculated based on polygon layers from the IUCN Red List of Thre
 -   Only RF shows continuous performance improvements beyond range sizes of \~5M km²
 
 ```{r range_size, echo = FALSE, message=FALSE, warnings=FALSE}
+
 df_join = range_maps %>% 
   dplyr::mutate(range_size = as.numeric(st_area(range_maps) / 1000000)) %>%  # range in sqkm
   sf::st_drop_geometry()
@@ -131,7 +131,27 @@ df_join = range_maps %>%
 df_plot = performance %>% 
   inner_join(df_join, by = c("species" = "name_matched"))
 
-# Create base plot
+# Function to calculate regression line with confidence intervals
+calculate_regression = function(data) {
+  # Fit log-linear model
+  model = lm(value ~ log(range_size), data = data)
+  new_x = data.frame(range_size = seq(min(data$range_size), max(data$range_size), length.out = 100))
+  # Predict on new_x directly; log() is applied inside the model formula
+  pred = predict(model, newdata = new_x, interval = "confidence")
+  data.frame(
+    range_size = new_x$range_size,  # Keep original scale for plotting
+    fit = pred[,"fit"],
+    lower = pred[,"lwr"],
+    upper = pred[,"upr"]
+  )
+}
+
+# Calculate regression lines for each model and metric combination
+regression_lines = df_plot %>%
+  group_by(model, metric) %>%
+  group_modify(~calculate_regression(.x))
+
+# Create base scatter plot
 plot <- plot_ly(
   data = df_plot,
   x = ~range_size,
@@ -152,14 +172,80 @@ plot <- plot_ly(
   )
 )
 
-# Add dropdown for selecting metric
+# Add regression lines and confidence intervals for each model
+for (model_name in unique(df_plot$model)) {
+  plot <- plot %>%
+    add_trace(
+      data = regression_lines %>% filter(model == model_name),
+      x = ~range_size,
+      y = ~fit,
+      type = 'scatter',
+      mode = 'lines',
+      legendgroup = paste0(model_name, "regression"),
+      name = paste(model_name, '(fit)'),
+      line = list(width = 2),
+      text = NULL,
+      transforms = list(
+        list(
+          type = 'filter',
+          target = ~metric,
+          operation = '=',
+          value = 'auc'
+        )
+      )
+    ) %>%
+    add_trace(
+      data = regression_lines %>% filter(model == model_name),
+      x = ~range_size,
+      y = ~upper,
+      type = 'scatter',
+      mode = 'lines',
+      legendgroup = paste0(model_name, "regression"),
+      line = list(width = 0),
+      showlegend = FALSE,
+      name = paste(model_name, '(upper CI)'),
+      text = NULL,
+      transforms = list(
+        list(
+          type = 'filter',
+          target = ~metric,
+          operation = '=',
+          value = 'auc'
+        )
+      )
+    ) %>%
+    add_trace(
+      data = regression_lines %>% filter(model == model_name),
+      x = ~range_size,
+      y = ~lower,
+      type = 'scatter',
+      mode = 'lines',
+      legendgroup = paste0(model_name, "regression"),
+      fill = 'tonexty',
+      fillcolor = 'rgba(0,0,0,0.2)',
+      line = list(width = 0),
+      showlegend = FALSE,
+      name = paste(model_name, '(lower CI)'),
+      text = NULL,
+      transforms = list(
+        list(
+          type = 'filter',
+          target = ~metric,
+          operation = '=',
+          value = 'auc'
+        )
+      )
+    )
+}
+
+# Add layout with dropdown
 plot <- plot %>%
   layout(
     title = "Model Performance vs. Range size",
-    xaxis = list(title = "Range size [sqkm]"),
+    xaxis = list(title = "Range size [sqkm]", type = "log"),
     yaxis = list(title = "Value"),
-    legend = list(x = 1.1, y = 0.5),  # Move legend to the right of the plot
-    margin = list(r = 150),  # Add right margin to accommodate legend
+    legend = list(x = 1.1, y = 0.5),
+    margin = list(r = 150),
     hovermode = 'closest',
     updatemenus = list(
       list(
@@ -341,7 +427,7 @@ bslib::card(plot, full_screen = T)
 Functional groups were assigned based on taxonomic order. The following groupings were used:
 
 | Functional group      | Taxonomic orders                                                      |
-|-----------------------|-----------------------------------------------------------------------|
+|------------------|-----------------------------------------------------|
 | large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla                    |
 | small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha |
 | arboreal              | Primates, Pilosa                                                      |
diff --git a/R/install_torch.R b/R/install_torch.R
deleted file mode 100644
index 597d7de..0000000
--- a/R/install_torch.R
+++ /dev/null
@@ -1,88 +0,0 @@
-# Installation
-
-options(timeout = 600) # increasing timeout is recommended since we will be downloading a 2GB file.
-# For Windows and Linux: "cpu", "cu117" are the only currently supported
-# For MacOS the supported are: "cpu-intel" or "cpu-m1"
-kind <- "cu117"
-version <- available.packages()["torch","Version"]
-options(repos = c(
-  torch = sprintf("https://torch-cdn.mlverse.org/packages/%s/%s/", kind, version),
-  CRAN = "https://cloud.r-project.org" # or any other from which you want to install the other R dependencies.
-))
-install.packages("torch", type = "binary")
-
-
-install.packages("torchvision")
-install.packages("luz")
-
-# Load packages
-library(torch)
-library(torchvision)
-library(luz)
-
-
-# Get dataset
-
-dir <- "../playground/dataset/mnist"
-
-train_ds <- mnist_dataset(
-  dir,
-  download = TRUE,
-  transform = transform_to_tensor
-)
-
-test_ds <- mnist_dataset(
-  dir,
-  train = FALSE,
-  transform = transform_to_tensor
-)
-
-train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
-test_dl <- dataloader(test_ds, batch_size = 128)
-
-# Plot example character
-image <- train_ds$data[1,1:28,1:28]
-image_df <- data.table::melt(image)
-ggplot(image_df, aes(x=Var2, y=Var1, fill=value))+
-  geom_tile(show.legend = FALSE) + 
-  xlab("") + ylab("") +
-  scale_fill_gradient(low="white", high="black")
-
-
-# Build network
-net <- nn_module(
-  "Net",
-  initialize = function() {
-    self$conv1 <- nn_conv2d(1, 32, 3, 1)
-    self$conv2 <- nn_conv2d(32, 64, 3, 1)
-    self$dropout1 <- nn_dropout2d(0.25)
-    self$dropout2 <- nn_dropout2d(0.5)
-    self$fc1 <- nn_linear(9216, 128)
-    self$fc2 <- nn_linear(128, 10)
-  },
-  forward = function(x) {
-    x %>%                                  # N * 1 * 28 * 28
-      self$conv1() %>%                     # N * 32 * 26 * 26
-      nnf_relu() %>%                       
-      self$conv2() %>%                     # N * 64 * 24 * 24
-      nnf_relu() %>% 
-      nnf_max_pool2d(2) %>%                # N * 64 * 12 * 12
-      self$dropout1() %>% 
-      torch_flatten(start_dim = 2) %>%     # N * 9216
-      self$fc1() %>%                       # N * 128
-      nnf_relu() %>% 
-      self$dropout2() %>% 
-      self$fc2()                           # N * 10
-  }
-)
-
-# Train
-fitted <- net %>%
-  setup(
-    loss = nn_cross_entropy_loss(),
-    optimizer = optim_adam,
-    metrics = list(
-      luz_metric_accuracy()
-    )
-  ) %>%
-  fit(train_dl, epochs = 10, valid_data = test_dl)
diff --git a/R/utils.R b/R/utils.R
index cafad51..ce483b7 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -1,18 +1,31 @@
-expand_bbox <- function(bbox, expand_factor = 0.1) {
+expand_bbox <- function(bbox, min_span = 1, expansion = 0.25) {
   # Get current bbox dimensions
   x_range <- bbox["xmax"] - bbox["xmin"]
   y_range <- bbox["ymax"] - bbox["ymin"]
+  x_expand = expansion
+  y_expand = expansion
+  
+  # Make sure the bbox is expanded by at least `min_span` in each dimension
+  if(x_range*x_expand < min_span){
+    x_range = min_span/2
+    x_expand = 1
+  }
+  
+  if(y_range*y_expand < min_span){
+    y_range = min_span/2
+    y_expand = 1
+  }
   
   # Expand the limits, adjusting both directions correctly
-  bbox["xmin"] <- max(bbox["xmin"] - (expand_factor * x_range), -180)
-  bbox["xmax"] <- min(bbox["xmax"] + (expand_factor * x_range), 180)
-  bbox["ymin"] <- max(bbox["ymin"] - (expand_factor * y_range), -90)
-  bbox["ymax"] <- min(bbox["ymax"] + (expand_factor * y_range), 90)
+  bbox["xmin"] <- max(bbox["xmin"] - (x_expand * x_range), -180)
+  bbox["xmax"] <- min(bbox["xmax"] + (x_expand * x_range), 180)
+  bbox["ymin"] <- max(bbox["ymin"] - (y_expand * y_range), -90)
+  bbox["ymax"] <- min(bbox["ymax"] + (y_expand * y_range), 90)
   
   return(bbox)
 }
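+
+# Example (hypothetical coordinates): a 0.1-degree bbox grows to span at least
+# 1 degree in each dimension
+# bb = sf::st_bbox(c(xmin = -60.2, ymin = -3.1, xmax = -60.1, ymax = -3.0))
+# expand_bbox(bb, min_span = 1, expansion = 0.25)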
 
-evaluate_model <- function(model, test_data, threshold = 0.5) {
+evaluate_model <- function(model, test_data) {
   # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
   # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
   
@@ -26,15 +39,21 @@ evaluate_model <- function(model, test_data, threshold = 0.5) {
   # Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall)
   
   # Predict probabilities
-  probs <- predict(model, test_data, type = "prob")$present
-  preds <- predict(model, test_data, type = "raw")
-  actual <- test_data$presence_fct
+  if(inherits(model, c("citodnn", "citodnnBootstrap"))){
+    probs <- predict(model, test_data, type = "response")[,1]
+    preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
+  } else {
+    probs <- predict(model, test_data, type = "prob")$P
+    preds <- predict(model, test_data, type = "raw")
+  }
+  
+  actual <- test_data$present_fct
   
   # Calculate AUC
-  auc <- pROC::roc(actual, probs, levels = c("present", "absent"), direction = ">")$auc
+  auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
   
   # Calculate confusion matrix
-  cm <- caret::confusionMatrix(preds, actual, positive = "present")
+  cm <- caret::confusionMatrix(preds, actual, positive = "P")
   
   # Return metrics
   return(list(
-- 
GitLab