library(dplyr)
library(tidyr)
library(cito)

source("R/utils.R")

# Restore the saved workspace objects; the code below relies on this file
# providing `model_data`.
load(file = "data/r_objects/model_data.RData")

# Keep only the attribute table: strip the sf geometry column.
model_data = sf::st_drop_geometry(model_data)

# ----------------------------------------------------------------------#
# Train model                                                        ####
# ----------------------------------------------------------------------#
# Predictor columns bio1 ... bio19.
predictors = sprintf("bio%d", 1:19)

# Response `present` modelled on all predictors plus `species` as a further term.
formula = as.formula(
  paste("present ~", paste(c(predictors, "species"), collapse = " + "))
)

# 1. Cross validation
# Fit one network per global fold and store each fit on disk.
# Create the output directory up front so that a completed training run is
# never lost because save() cannot write its file.
fit_dir = "data/r_objects/msdm_onehot_results"
dir.create(fit_dir, recursive = TRUE, showWarnings = FALSE)

for(fold in 1:5){
  # Prepare data: keep every background record plus all records that are not
  # in the held-out fold. `!!fold` injects the loop variable, so a column
  # named `fold` in model_data cannot shadow it inside the data mask.
  data_train = model_data %>% 
    dplyr::filter(record_type == "background" | fold_global != !!fold)
  
  # Run model: three hidden layers of 200 units, binomial loss, Adam optimizer;
  # 20% of the training data is held out for validation / early stopping.
  msdm_onehot_fit = dnn(
    formula,
    data = data_train,
    hidden = c(200L, 200L, 200L),
    loss = "binomial",
    epochs = 5000, 
    lr = 0.001,   
    batchsize = 1024,
    dropout = 0.25,
    burnin = 50,
    optimizer = config_optimizer("adam"),
    early_stopping = 100,
    validation = 0.2,
    device = "cuda"
  )
  
  # One .RData file per fold; the evaluation step loads it under the name
  # `msdm_onehot_fit`.
  save(
    msdm_onehot_fit, 
    file = file.path(fit_dir, paste0("msdm_onehot_fit_fold", fold, ".RData"))
  )
}

# 2. Full model
# Fit the same architecture on the complete data set (no fold held out).
# Resolve the output file first and make sure its directory exists, so the
# save() after the long training run cannot fail on a missing folder.
full_fit_file = "data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData"
dir.create(dirname(full_fit_file), recursive = TRUE, showWarnings = FALSE)

msdm_onehot_fit = dnn(
  formula,
  data = model_data,
  hidden = c(200L, 200L, 200L),
  loss = "binomial",
  epochs = 5000, 
  lr = 0.001,   
  batchsize = 1024,
  dropout = 0.25,
  burnin = 50,
  optimizer = config_optimizer("adam"),
  early_stopping = 100,
  validation = 0.2,
  device = "cuda"
)

save(msdm_onehot_fit, file = full_fit_file)

# ----------------------------------------------------------------------#
# Evaluate model                                                     ####
# ----------------------------------------------------------------------#
# Evaluate every cross-validation fit on its held-out fold, separately for each
# species, and collect the metrics in one long table with the columns
# species, obs, fold_global, model, metric, value.
msdm_onehot_performance = lapply(1:5, function(fold){
  # Loads `msdm_onehot_fit` for this fold into the function's environment.
  load(paste0("data/r_objects/msdm_onehot_results/msdm_onehot_fit_fold", fold, ".RData"))
  
  # Held-out records of this fold, one data frame per species.
  # `!!fold` injects the function argument so that a column named `fold`
  # in model_data cannot shadow it.
  data_test_split = model_data %>% 
    dplyr::filter(fold_global == !!fold) %>% 
    dplyr::group_split(species)
  
  lapply(data_test_split, function(data_test_spec){
    species = data_test_spec$species[1]
    
    # Evaluation can fail for individual species; keep going with NA metrics,
    # but report which species/fold failed and why instead of hiding the error.
    performance = tryCatch({
      evaluate_model(msdm_onehot_fit, data_test_spec)
    }, error = function(e){
      warning(
        "Evaluation failed for species '", as.character(species), 
        "' in fold ", fold, ": ", conditionMessage(e), 
        call. = FALSE
      )
      list(auc = NA_real_, accuracy = NA_real_, kappa = NA_real_, 
           precision = NA_real_, recall = NA_real_, f1 = NA_real_, 
           tp = NA_real_, fp = NA_real_, tn = NA_real_, fn = NA_real_)
    })
    
    # Number of this species' records outside the held-out fold.
    # NOTE(review): this counts every record_type for the species - confirm
    # whether background records are meant to be included.
    obs = nrow(dplyr::filter(model_data, species == !!species, fold_global != !!fold))
    
    performance %>% 
      as_tibble() %>% 
      mutate(
        species = !!species,
        obs = !!obs,
        fold_global = !!fold,
        model = "msdm_onehot"
      ) %>% 
      tidyr::pivot_longer(-any_of(c("species", "obs", "fold_global", "model")), names_to = "metric", values_to = "value")
  }) %>% 
    bind_rows()
}) %>% 
  bind_rows()

save(msdm_onehot_performance, file = "data/r_objects/msdm_onehot_performance.RData")