diff --git a/R/04_01_modelling_ssdm.R b/R/04_01_modelling_ssdm.R
index e5240db164d81d630ad0588b1e841f89f84173ce..d44330917af7440a34f9421fd4601e810a3f08ce 100644
--- a/R/04_01_modelling_ssdm.R
+++ b/R/04_01_modelling_ssdm.R
@@ -87,16 +87,16 @@ for(pa_spec in data_split){
   })
   
   ## Gradient Boosted Machine ####
-  xgb_performance = tryCatch({
-    xgb_fit = train(
+  gbm_performance = tryCatch({
+    gbm_fit = train(
       x = data_train[, predictors],
       y = data_train$present_fct,
-      method = "xgbTree",
+      method = "gbm",
       trControl = train_ctrl,
       tuneLength = 8,
       verbose = F
     )
-    evaluate_model(xgb_fit, data_test)
+    evaluate_model(gbm_fit, data_test)
   }, error = function(e){
     na_performance
   })
diff --git a/R/04_05_modelling_msdm_rf.R b/R/04_05_modelling_msdm_rf.R
index 29819aa8f772b85e87abd548b7f70c3b4f55072e..5e13f48214f0b20ef85842840f3de660d49ab775 100644
--- a/R/04_05_modelling_msdm_rf.R
+++ b/R/04_05_modelling_msdm_rf.R
@@ -89,7 +89,7 @@ save(rf_fit, file = "data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
 # Evaluate model ####
 # ----------------------------------------------------------------------#
 msdm_rf_performance = lapply(1:5, function(fold){
-  load(paste0("data/r_objects/msdm_rf/rf_fit_fold", fold, ".RData"))
+  load(paste0("data/r_objects/msdm_rf/msdm_rf_fit_fold", fold, ".RData"))
   
   test_data =dplyr::filter(model_data, fold_eval == fold) %>% 
     sf::st_drop_geometry()
diff --git a/R/05_02_publication_analysis.R b/R/05_02_publication_analysis.R
index ee481d3db8a09c5a9c00bd8f5f615f247e407f94..f1b9d262008a17cde8ad3251f1b54d31c6b3f037 100644
--- a/R/05_02_publication_analysis.R
+++ b/R/05_02_publication_analysis.R
@@ -154,7 +154,7 @@
 library(cito)
 library(ranger)
 
 # Define plotting function
-plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm_embed", "msdm_onehot")){
+plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm_embed", "msdm_onehot", "msdm_rf")){
   # Species data
   load("data/r_objects/range_maps.RData")
@@ -166,7 +166,7 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
   
   # Extract raster values into df
   bbox_spec = sf::st_bbox(pa_spec) %>% 
-    expand_bbox(expansion = 0.75)
+    expand_bbox(expansion = 0.5)
   
   raster_crop = terra::crop(raster_data, bbox_spec)
   new_data = raster_crop %>% 
@@ -185,6 +185,10 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
     load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
     probabilities = predict(msdm_embed_fit, new_data, type = "response")[,1]
     predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
+  } else if(algorithm == "msdm_rf"){
+    load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
+    probabilities = predict(rf_fit, new_data, type = "prob")$P # keep probabilities available, consistent with the other branches
+    predictions = predict(rf_fit, new_data, type = "raw")
   } else if(algorithm == "nn") {
     load(paste0("data/r_objects/ssdm_results/full_models/", spec, "_nn_fit.RData"))
     new_data_tmp = dplyr::select(new_data, -species)
@@ -218,46 +222,89 @@ raster_data = terra::rast(raster_filepaths) %>% 
   terra::crop(sf::st_bbox(sa_polygon)) %>% 
   terra::project(sf::st_crs(model_data)$input)
 
-specs = sort(sample(levels(model_data$species), 20))
+specs = sort(sample(levels(model_data$species), 30))
 for(spec in specs){
-  pdf(file = paste0("plots/range_predictions/", spec, ".pdf"))
+  pdf(file = paste0("plots/range_predictions/", spec, " (msdm).pdf"))
   tryCatch({
-    plot_predictions(spec, model_data, raster_data, algorithms = c("gam", "xgb", "rf", "msdm_embed", "msdm_onehot"))
+    plot_predictions(spec, model_data, raster_data, algorithms = c("msdm_embed", "msdm_onehot", "msdm_rf"))
raster_data, algorithms = c("msdm_embed", "msdm_onehot", "msdm_rf")) }, finally = { dev.off() }) } # ------------------------------------------------------------------ # -# 4. Compare msdm predictions (embed vs. onehot) #### +# 4. Compare predictions across species #### # ------------------------------------------------------------------ # load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData") load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData") +load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData") -plot_embed_vs_onehot = function(){ +compare_species_predictions = function(model, sample_size){ + specs = sample(unique(model_data$species), 2, replace = F) + df1 = spatSample(raster_data, sample_size, replace = F, as.df = T) %>% + drop_na() %>% + dplyr::mutate(species = specs[1]) + df2= df1 %>% + dplyr::mutate(species = specs[2]) + + p1 = predict_new(model, df1) + p2 = predict_new(model, df2) + + plot(x = p1, y = p2, + xlab = df1$species[1], + ylab = df2$species[1], xlim=c(0,1), ylim = c(0,1), pch = 16, + main = deparse(substitute(model))) + abline(a = 0, b = 1) + text(x = c(0.1,0.9,0.1,0.9), y = c(0.1,0.1,0.9,0.9), c("AA", "PA", "AP", "PP"), col = "red") + mtext(paste0("(R² = ", round(cor(p1,p2) ^ 2, 3), ")"), side = 3, cex = 0.8) + abline(h=0.5) + abline(v=0.5) +} + +compare_species_predictions(msdm_embed_fit, 500) +compare_species_predictions(msdm_onehot_fit, 500) +compare_species_predictions(rf_fit, 500) + +## --> Predictions for different species are weakly/moderately correlated in NN models (makes sense) +## --> Predictions for different species are always highly correlated in RF model (seems problematioc) + +# ------------------------------------------------------------------ # +# 5. Compare predictions across models #### +# ------------------------------------------------------------------ # +load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData") +load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData") +load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData") + +compare_model_predictions = function(model1, model2, sample_size){ spec = sample(unique(model_data$species), 1) - new_data = spatSample(raster_data, 100, replace = F, as.df = T) %>% + new_data = spatSample(raster_data, sample_size, replace = F, as.df = T) %>% drop_na() %>% dplyr::mutate(species = spec) - p1_em = predict(msdm_embed_fit, new_data, type = "response") - p1_oh = predict(msdm_onehot_fit, new_data, type = "response") + p1 = predict_new(model1, new_data) + p2 = predict_new(model2, new_data) # Compare onehot vs embed - plot(x = p1_em, y = p1_oh, xlab = "embed", ylab = "onehot", xlim=c(0,1), ylim = c(0,1), pch = 16, main = spec) + plot(x = p1, y = p2, + xlab = deparse(substitute(model1)), + ylab = deparse(substitute(model2)), + xlim=c(0,1), ylim = c(0,1), pch = 16, main = spec) abline(a = 0, b = 1) text(x = c(0.1,0.9,0.1,0.9), y = c(0.1,0.1,0.9,0.9), c("AA", "PA", "AP", "PP"), col = "red") + mtext(paste0("(R² = ", round(cor(p1,p2) ^ 2, 3), ")"), side = 3, cex = 0.8) abline(h=0.5) abline(v=0.5) } -plot_embed_vs_onehot() +compare_model_predictions(msdm_embed_fit, msdm_onehot_fit, 500) +compare_model_predictions(msdm_embed_fit, rf_fit, 500) +compare_model_predictions(msdm_onehot_fit, rf_fit, 500) -## --> Predictions are similar between msdms -## --> Onehot model seems to predict slightly more presences than embedded model (???) 
+## --> Predictions for the same species from NN_embed and NN_onehot are moderately/strongly correlated (makes sense)
+## --> Predictions for the same species from NN and RF are weakly/not correlated (seems problematic)
 
 # ------------------------------------------------------------------ #
-# 5. Compare species embeddings across folds ####
+# 6. Compare species embeddings across folds ####
 # ------------------------------------------------------------------ #
 obs_count = model_data %>% 
   sf::st_drop_geometry() %>% 
@@ -359,7 +406,7 @@ ggplot(data = df_plot, aes(x = obs, y = value)) + 
 ## --> Some potential for informed embeddings with respect to rare species?
 
 # ------------------------------------------------------------------ #
-# 6. Analyse embedding of full model ####
+# 7. Analyse embedding of full model ####
 # ------------------------------------------------------------------ #
 load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
 species_lookup = data.frame(species = levels(model_data$species)) %>% 
diff --git a/R/utils.R b/R/utils.R
index b433a44f336ebd9987c8ca3671057c965b610411..1f9ec189484331ae200786eae0515028d9a1452b 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -25,6 +25,28 @@ expand_bbox <- function(bbox, min_span = 1, expansion = 0.25) {
   return(bbox)
 }
 
+predict_new = function(model, data, type = "prob"){
+  stopifnot(type %in% c("prob", "class"))
+  
+  if(class(model) %in% c("citodnn", "citodnnBootstrap")){
+    probs = predict(model, data, type = "response")[,1]
+    if(type == "prob"){
+      return(probs)
+    } else {
+      preds = factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
+      return(preds)
+    }
+  } else {
+    probs = predict(model, data, type = "prob")$P
+    if(type == "prob"){
+      return(probs)
+    } else {
+      preds = predict(model, data, type = "raw")
+      return(preds)
+    }
+  }
+}
+
 evaluate_model <- function(model, data) {
   # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
   # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
@@ -40,7 +62,6 @@ evaluate_model <- function(model, data) {
 
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    data = dplyr::select(data, any_of(all.vars(model$old_formula)))
     probs = predict(model, data, type = "response")[,1]
     preds = factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
@@ -72,56 +93,3 @@ evaluate_model <- function(model, data) {
     )
   )
 }
-
-evaluate_multiclass_model <- function(model, test_data, k) {
-  # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
-  # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
-
-  # Precision: The proportion of true positives out of all instances predicted as positive.
-  # Formula: Precision = TP / (TP + FP)
-
-  # Recall (Sensitivity): The proportion of true positives out of all actual positive instances.
-  # Formula: Recall = TP / (TP + FN)
-
-  # F1 Score: The harmonic mean of Precision and Recall, balancing the two metrics.
-  # Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall)
-  target_species = unique(test_data$species)
-  checkmate::assert_character(target_species, len = 1, any.missing = F)
-
-  new_data = dplyr::select(test_data, -species)
-
-  # Predict probabilities
-  if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    preds_overall = predict(model, as.matrix(new_data), type = "response")
-    probs <- as.vector(preds_overall[,target_species])
-
-    rank = apply(preds_overall, 1, function(x){ # Top-K approach
-      x_sort = sort(x, decreasing = T)
-      return(which(names(x_sort) == target_species))
-    })
-    top_k = as.character(as.numeric(rank <= k))
-    preds <- factor(top_k, levels = c("0", "1"), labels = c("A", "P"))
-  } else {
-    stop("Unsupported model type: ", class(model))
-  }
-
-  actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
-
-  # Calculate AUC
-  auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
-
-  # Calculate confusion matrix
-  cm <- caret::confusionMatrix(preds, actual, positive = "P")
-
-  # Return metrics
-  return(
-    list(
-      AUC = as.numeric(auc),
-      Accuracy = cm$overall["Accuracy"],
-      Kappa = cm$overall["Kappa"],
-      Precision = cm$byClass["Precision"],
-      Recall = cm$byClass["Recall"],
-      F1 = cm$byClass["F1"]
-    )
-  )
-}
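
# ------------------------------------------------------------------ #
# Reviewer note: usage sketch for the new predict_new() helper ####
# ------------------------------------------------------------------ #
## Not part of the patch — a minimal sketch of how predict_new() (R/utils.R)
## unifies predictions across cito DNNs and caret models. Assumes the session
## state from R/05_02_publication_analysis.R (loaded libraries, model_data,
## raster_data) and the full-model objects used above.
load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")

new_data = spatSample(raster_data, 100, replace = F, as.df = T) %>% 
  drop_na() %>% 
  dplyr::mutate(species = sample(unique(model_data$species), 1))

probs_nn = predict_new(msdm_embed_fit, new_data)          # numeric P(present) in [0,1], cito branch
probs_rf = predict_new(rf_fit, new_data)                  # same scale for the caret RF
preds_rf = predict_new(rf_fit, new_data, type = "class")  # factor with levels A/P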
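
# ------------------------------------------------------------------ #
# Reviewer note: probing the high cross-species RF correlation ####
# ------------------------------------------------------------------ #
## Sketch only, under the same assumptions as above. One way to follow up on
## the observation in section 4 (rf_fit predicts near-identical values for
## different species): if the forest rarely splits on `species`, a cross-species
## R² near 1 is expected. Note caret::varImp() only works if the underlying
## ranger model was trained with an importance mode (e.g. importance = "impurity").
caret::varImp(rf_fit)

# Average the cross-species R² over repeated random species pairs
r2 = replicate(20, {
  specs = sample(unique(model_data$species), 2)
  df = spatSample(raster_data, 200, replace = F, as.df = T) %>% drop_na()
  p1 = predict_new(rf_fit, dplyr::mutate(df, species = specs[1]))
  p2 = predict_new(rf_fit, dplyr::mutate(df, species = specs[2]))
  cor(p1, p2)^2
})
summary(r2)  # values near 1 suggest the RF effectively ignores species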