library(dplyr)
library(tidyr)
library(cito)

source("R/utils.R")  # presumably defines evaluate_multiclass_model() (used below) — confirm
load("data/r_objects/model_data.RData")  # expected to provide `model_data`

# ----------------------------------------------------------------------#
# Prepare data                                                       ####
# ----------------------------------------------------------------------#
# Use only presences for training
train_data = dplyr::filter(model_data, present == 1, train == 1)

# Evaluate on presences + pseudo-absences for comparability with binary models
test_data = dplyr::filter(model_data, train == 0)

predictors = paste0("layer_", 1:19)

# ----------------------------------------------------------------------#
# Train model                                                        ####
# ----------------------------------------------------------------------#
formula = as.formula(paste0("species ~ ", paste(predictors, collapse = "+")))

# Batch sizes must be whole numbers; nrow / 5 is generally fractional.
# Integer division gives roughly five mini-batches per epoch (at least 1 row).
batch_size = max(1L, nrow(train_data) %/% 5L)

# Prefer the GPU, but fall back to the CPU instead of erroring when no
# CUDA device is available.
device = if (torch::cuda_is_available()) "cuda" else "cpu"
if (device != "cuda") {
  message("CUDA not available; training msdm_fit_multiclass on CPU.")
}

msdm_fit_multiclass = dnn(
  formula,
  data = train_data,
  hidden = c(200L, 200L, 200L),
  loss = "softmax",
  activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
  epochs = 5000L,
  lr = 0.01,
  batchsize = batch_size,
  dropout = 0.1,
  burnin = 100,
  optimizer = config_optimizer("adam", weight_decay = 0.001),
  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
  early_stopping = 250,
  validation = 0.3,
  device = device
)

save(msdm_fit_multiclass, file = "data/r_objects/msdm_fit_multiclass.RData")

# ----------------------------------------------------------------------#
# Evaluate model                                                     ####
# ----------------------------------------------------------------------#
# One test tibble per species (group_split() returns ungrouped tibbles).
data_split = test_data %>%
  group_by(species) %>%
  group_split()

msdm_results_multiclass = lapply(data_split, function(data_spec) {
  species_name = data_spec$species[1]

  msdm_performance = tryCatch(
    evaluate_multiclass_model(msdm_fit_multiclass, data_spec, k = 10),  # Top-k accuracy
    error = function(e) {
      # Best-effort: record NA metrics for this species, but surface the
      # failure instead of swallowing it silently.
      warning(
        "Evaluation failed for species ", species_name, ": ",
        conditionMessage(e),
        call. = FALSE
      )
      list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
    }
  )

  tibble(
    species = species_name,
    # Number of rows for this species in the full model_data (all splits,
    # presences and pseudo-absences alike).
    obs = sum(model_data$species == species_name, na.rm = TRUE),
    model = "MSDM_multiclass",
    auc = msdm_performance$AUC,
    accuracy = msdm_performance$Accuracy,
    kappa = msdm_performance$Kappa,
    precision = msdm_performance$Precision,
    recall = msdm_performance$Recall,
    f1 = msdm_performance$F1
  )
}) %>%
  bind_rows()

save(msdm_results_multiclass, file = "data/r_objects/msdm_results_multiclass.RData")