# Multi-species DNN ("msdm_onehot"): species enters the model formula as an
# additional predictor next to the variables bio1..bio19.
#
# The script
#   1. fits one model per cross-validation fold and one model on all data, and
#   2. scores every fold model per species on the records of its held-out fold.

library(dplyr)
library(tidyr)
library(cito)

# evaluate_model() is not defined in this file; presumably it is provided here.
source("R/utils.R")

# Presumably restores `model_data` (with a geometry column, given the drop
# below) -- TODO confirm against the script that writes this file.
load("data/r_objects/model_data.RData")

model_data = model_data %>%
  sf::st_drop_geometry()

# Fold ids used for both training and evaluation, kept in one place so the two
# loops cannot get out of sync.
cv_folds = 1:5

# Directory that receives the fitted models.
results_dir = "data/r_objects/msdm_onehot_results"

# ----------------------------------------------------------------------#
# Train model                                                        ####
# ----------------------------------------------------------------------#
predictors = paste0("bio", 1:19)
formula = as.formula(
  paste0("present ~ ", paste(predictors, collapse = "+"), " + ", "species")
)

# Fit the DNN with the configuration shared by the fold models and the full
# model, so that both are guaranteed to use identical settings.
#
# model_formula: model formula passed on to cito::dnn().
# data:          data frame holding the response and all predictors.
# Returns the fitted object produced by cito::dnn().
fit_msdm_onehot = function(model_formula, data) {
  dnn(
    model_formula,
    data = data,
    hidden = c(200L, 200L, 200L),
    loss = "binomial",
    epochs = 5000,
    lr = 0.001,
    batchsize = 1024,
    dropout = 0.25,
    burnin = 50,
    optimizer = config_optimizer("adam"),
    early_stopping = 100,
    validation = 0.2,
    device = "cuda"
  )
}

# Create the output directory up front: otherwise a missing directory is only
# noticed when save() fails after a complete (expensive) training run.
dir.create(results_dir, recursive = TRUE, showWarnings = FALSE)

# 1. Cross validation
for (fold in cv_folds) {
  # Prepare data: keep every background record plus all records that are not
  # assigned to the held-out fold. `!!fold` pins the loop variable so that a
  # column called `fold` in model_data could never be picked up instead.
  # NOTE(review): background records are kept even if their fold_global equals
  # the held-out fold, while the evaluation below selects the test set by
  # fold_global alone -- confirm this overlap is intended.
  data_train = model_data %>%
    dplyr::filter(record_type == "background" | fold_global != !!fold)

  # Run model
  msdm_onehot_fit = fit_msdm_onehot(formula, data_train)

  save(
    msdm_onehot_fit,
    file = file.path(results_dir, paste0("msdm_onehot_fit_fold", fold, ".RData"))
  )
}

# Full model
msdm_onehot_fit = fit_msdm_onehot(formula, model_data)

save(
  msdm_onehot_fit,
  file = file.path(results_dir, "msdm_onehot_fit_full.RData")
)

# ----------------------------------------------------------------------#
# Evaluate model                                                     ####
# ----------------------------------------------------------------------#

# Metrics reported when evaluate_model() fails for a species/fold combination.
na_performance = list(
  auc = NA_real_,
  accuracy = NA_real_,
  kappa = NA_real_,
  precision = NA_real_,
  recall = NA_real_,
  f1 = NA_real_,
  tp = NA_real_,
  fp = NA_real_,
  tn = NA_real_,
  fn = NA_real_
)

msdm_onehot_performance = lapply(cv_folds, function(fold) {
  # load() restores this fold's `msdm_onehot_fit` into the function's own
  # environment, so the full model in the global environment is left untouched.
  load(file.path(results_dir, paste0("msdm_onehot_fit_fold", fold, ".RData")))

  # One test data frame per species present in the held-out fold.
  data_test_split = model_data %>%
    dplyr::filter(fold_global == !!fold) %>%
    dplyr::group_split(species)

  lapply(data_test_split, function(data_test_spec) {
    target_species = data_test_spec$species[1]

    # Best effort: a failing evaluation must not abort the whole run, but it is
    # reported instead of being swallowed silently.
    performance = tryCatch(
      evaluate_model(msdm_onehot_fit, data_test_spec),
      error = function(e) {
        warning(
          "Evaluation failed for species '", target_species,
          "' in fold ", fold, ": ", conditionMessage(e),
          call. = FALSE, immediate. = TRUE
        )
        na_performance
      }
    )

    # Number of records of this species whose fold_global differs from the
    # held-out fold.
    n_obs = model_data %>%
      dplyr::filter(species == !!target_species, fold_global != !!fold) %>%
      nrow()

    # Long format: one row per metric, tagged with species, fold and model.
    performance %>%
      as_tibble() %>%
      mutate(
        species = !!target_species,
        obs = !!n_obs,
        fold_global = !!fold,
        model = "msdm_onehot"
      ) %>%
      tidyr::pivot_longer(
        -any_of(c("species", "obs", "fold_global", "model")),
        names_to = "metric",
        values_to = "value"
      )
  }) %>%
    bind_rows()
}) %>%
  bind_rows()

save(
  msdm_onehot_performance,
  file = "data/r_objects/msdm_onehot_performance.RData"
)