library(dplyr)
library(tidyr)
library(cito)

source("R/utils.R")

load("data/r_objects/model_data.RData")

# ----------------------------------------------------------------------#
# Prepare data                                                       ####
# ----------------------------------------------------------------------#
# Training set: presence records only, restricted to the training split.
train_data = model_data %>%
  dplyr::filter(present == 1, train == 1)

# Test set: every held-out row (presences and pseudo-absences), so the
# results are comparable with the binary models.
test_data = model_data %>%
  dplyr::filter(train == 0)

# Environmental covariates layer_1 ... layer_19.
predictors = sprintf("layer_%d", seq_len(19))

# ----------------------------------------------------------------------#
# Train model                                                        ####
# ----------------------------------------------------------------------#
# Model formula: species ~ layer_1 + ... + layer_19, built from `predictors`
# without string pasting.
formula = reformulate(predictors, response = "species")

# Mini-batch size of roughly one fifth of the training rows. Rounded to an
# integer explicitly (e.g. 1003 / 5 would otherwise be 200.6) and kept >= 1
# so very small training sets still produce a valid batch size.
# NOTE(review): `validation = 0.3` below holds out part of train_data, so the
# number of batches per epoch ends up below five — confirm this is intended.
train_batchsize = max(1L, as.integer(round(nrow(train_data) / 5)))

# Multiclass DNN: three hidden layers of 200 units with a softmax output over
# species; fit with Adam + weight decay, LR reduced on plateau, early stopping.
# NOTE(review): device is hard-coded to "cuda"; this fails on machines without
# a CUDA-capable GPU.
msdm_fit_multiclass = dnn(
  formula,
  data = train_data,
  hidden = c(200L, 200L, 200L),
  loss = "softmax",
  activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
  epochs = 5000L,
  lr = 0.01,
  batchsize = train_batchsize,
  dropout = 0.1,
  burnin = 100,
  optimizer = config_optimizer("adam", weight_decay = 0.001),
  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
  early_stopping = 250,
  validation = 0.3,
  device = "cuda"
)

save(msdm_fit_multiclass, file = "data/r_objects/msdm_fit_multiclass.RData")

# ----------------------------------------------------------------------#
# Evaluate model                                                     ####
# ----------------------------------------------------------------------#
# One tibble of held-out rows per species.
data_split = test_data %>%
  group_by(species) %>%
  group_split()

# Per-species row counts in model_data, computed once up front instead of
# re-filtering model_data inside the loop for every species.
# NOTE(review): this counts ALL model_data rows of the species (train and
# test, presences and pseudo-absences) — confirm that is what `obs` means.
obs_per_species = dplyr::count(model_data, species, name = "obs")

# Metrics reported for a species whose evaluation fails.
failed_performance = list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)

# Evaluate the multiclass model for each species and collect one summary row
# per species.
msdm_results_multiclass = lapply(data_split, function(data_spec){
  species_name = data_spec$species[1]

  msdm_performance = tryCatch({
    evaluate_multiclass_model(msdm_fit_multiclass, data_spec, k = 10) # Top-k accuracy
  }, error = function(e){
    # Keep going with NA metrics, but surface the failure instead of hiding it.
    warning(
      "Evaluation failed for species '", species_name, "': ",
      conditionMessage(e),
      call. = FALSE
    )
    failed_performance
  })

  tibble(
    species = species_name,
    obs = obs_per_species$obs[match(species_name, obs_per_species$species)],
    model = "MSDM_multiclass",
    auc = msdm_performance$AUC,
    accuracy = msdm_performance$Accuracy,
    kappa = msdm_performance$Kappa,
    precision = msdm_performance$Precision,
    recall = msdm_performance$Recall,
    f1 = msdm_performance$F1
  )
}) %>% bind_rows()

save(msdm_results_multiclass, file = "data/r_objects/msdm_results_multiclass.RData")