library(dplyr)
library(tidyr)
library(furrr)
library(caret)
library(cito)
library(pROC)

source("R/utils.R")

load("data/r_objects/ssdm_data.RData")

# One data frame per species; each species is modelled independently below.
pa_split <- split(ssdm_data, ssdm_data$species)

# ----------------------------------------------------------------------#
# Train models                                                       ####
# ----------------------------------------------------------------------#
future::plan("multisession", workers = 8)

# For every species: fit a random forest, a GBM, a GLM and a neural network
# on the rows flagged `train == 1`, evaluate each on the rows flagged
# `train == 0`, and return one tibble with one row per model.
# Species with fewer than 10 rows yield NULL and are dropped by bind_rows().
ssdm_results <- furrr::future_map(
  pa_split,
  .options = furrr::furrr_options(seed = 123),
  .f = function(pa_spec) {
    # Initial check
    if (nrow(pa_spec) < 10) {
      return(NULL)
    }

    species_name <- as.character(pa_spec$species[1])

    # caret needs a factor response with syntactically valid level names
    # when classProbs = TRUE: "A" = absent (0), "P" = present (1).
    pa_spec$present_fct <- factor(
      pa_spec$present,
      levels = c("0", "1"),
      labels = c("A", "P")
    )

    train_data <- dplyr::filter(pa_spec, train == 1)
    test_data <- dplyr::filter(pa_spec, train == 0)

    # Placeholder returned when a model cannot be fitted or evaluated.
    # NA_real_ keeps the metric columns numeric even if all models fail.
    na_performance <- list(
      AUC = NA_real_,
      Accuracy = NA_real_,
      Kappa = NA_real_,
      Precision = NA_real_,
      Recall = NA_real_,
      F1 = NA_real_
    )

    # Evaluate `expr` (lazily, inside tryCatch); on error, report which
    # model/species failed and why, then fall back to the NA placeholder
    # so the remaining models and species are still processed.
    with_fallback <- function(model_label, expr) {
      tryCatch(
        expr,
        error = function(e) {
          warning(
            "Model ", model_label, " failed for species ", species_name,
            ": ", conditionMessage(e),
            call. = FALSE
          )
          na_performance
        }
      )
    }

    # Define predictors
    predictors <- paste0("layer_", 1:19)

    # Define caret training routine #####
    # One resampling iteration per fold value: the training rows of an
    # iteration are all rows that do NOT belong to that fold.
    index_train <- lapply(unique(sort(train_data$fold)), function(x) {
      which(train_data$fold != x)
    })

    train_ctrl <- trainControl(
      search = "grid",
      classProbs = TRUE,
      index = index_train,
      summaryFunction = twoClassSummary,
      savePredictions = "final"
    )

    # NOTE(review): evaluate_model() is not defined in this script
    # (presumably in R/utils.R); it is expected to return a list with the
    # same names as `na_performance` -- confirm.

    # Random Forest #####
    rf_performance <- with_fallback("RF", {
      rf_grid <- expand.grid(
        mtry = c(3, 7, 11, 15, 19) # Number of randomly selected predictors
      )

      rf_fit <- caret::train(
        x = train_data[, predictors],
        y = train_data$present_fct,
        method = "rf",
        metric = "ROC",
        tuneGrid = rf_grid,
        trControl = train_ctrl
      )

      evaluate_model(rf_fit, test_data)
    })

    # Gradient Boosted Machine ####
    gbm_performance <- with_fallback("GBM", {
      gbm_grid <- expand.grid(
        n.trees = c(100, 500, 1000, 1500),  # Number of trees
        interaction.depth = c(3, 5, 7),     # Maximum depth of each tree
        shrinkage = c(0.01, 0.005, 0.001),  # Lower learning rates
        n.minobsinnode = c(10, 20)          # Minimum number of observations in nodes
      )

      gbm_fit <- train(
        x = train_data[, predictors],
        y = train_data$present_fct,
        method = "gbm",
        metric = "ROC",
        verbose = FALSE,
        tuneGrid = gbm_grid,
        trControl = train_ctrl
      )

      evaluate_model(gbm_fit, test_data)
    })

    # Generalized Linear Model ####
    glm_performance <- with_fallback("GLM", {
      glm_fit <- train(
        x = train_data[, predictors],
        y = train_data$present_fct,
        method = "glm",
        family = binomial,
        metric = "ROC",
        preProcess = c("center", "scale"),
        trControl = train_ctrl
      )

      evaluate_model(glm_fit, test_data)
    })

    # Neural Network ####
    # NOTE(review): device = "cuda" is used inside 8 parallel workers;
    # confirm the GPU can be shared by that many processes.
    nn_performance <- with_fallback("NN", {
      # A quarter of the training rows per batch, as a whole number
      # (at least 1) so the batch size is never fractional or zero.
      nn_batchsize <- max(1L, nrow(train_data) %/% 4L)

      nn_fit <- dnn(
        X = train_data[, predictors],
        Y = train_data$present,
        hidden = c(500L, 200L, 50L),
        loss = "binomial",
        activation = c("sigmoid", "leaky_relu", "leaky_relu"),
        epochs = 500L,
        lr = 0.02,
        baseloss = 10,
        batchsize = nn_batchsize,
        dropout = 0.1, # Regularization
        optimizer = config_optimizer("adam", weight_decay = 0.001),
        lr_scheduler = config_lr_scheduler(
          "reduce_on_plateau",
          patience = 100,
          factor = 0.7
        ),
        early_stopping = 200, # Stop training when validation loss does not decrease anymore
        validation = 0.3,     # Used for early stopping and lr_scheduler
        device = "cuda",
        bootstrap = 5
      )

      evaluate_model(nn_fit, test_data)
    })

    # Summarize results: one row per model for this species
    performance_summary <- tibble(
      species = pa_spec$species[1],
      obs = nrow(pa_spec),
      model = c("RF", "GBM", "GLM", "NN"),
      auc = c(
        rf_performance$AUC, gbm_performance$AUC,
        glm_performance$AUC, nn_performance$AUC
      ),
      accuracy = c(
        rf_performance$Accuracy, gbm_performance$Accuracy,
        glm_performance$Accuracy, nn_performance$Accuracy
      ),
      kappa = c(
        rf_performance$Kappa, gbm_performance$Kappa,
        glm_performance$Kappa, nn_performance$Kappa
      ),
      precision = c(
        rf_performance$Precision, gbm_performance$Precision,
        glm_performance$Precision, nn_performance$Precision
      ),
      recall = c(
        rf_performance$Recall, gbm_performance$Recall,
        glm_performance$Recall, nn_performance$Recall
      ),
      f1 = c(
        rf_performance$F1, gbm_performance$F1,
        glm_performance$F1, nn_performance$F1
      )
    )

    performance_summary
  }
)

# Stack the per-species tibbles (NULL entries from skipped species vanish).
ssdm_results <- bind_rows(ssdm_results)

save(ssdm_results, file = "data/r_objects/ssdm_results.RData")