library(dplyr)
library(tidyr)
library(cito)

source("R/utils.R")

# Restore the saved modelling table; the file is expected to provide the
# object `model_data` that is processed below.
load("data/r_objects/model_data.RData")

# Prepare the table step by step:
# 1) keep only records that have an evaluation fold assigned,
# 2) encode species as a factor,
# 3) drop the sf geometry so a plain table is handed to the model.
model_data = dplyr::filter(model_data, !is.na(fold_eval))
model_data = dplyr::mutate(model_data, species = as.factor(species))
model_data = sf::st_drop_geometry(model_data)

# ----------------------------------------------------------------------#
# Train model                                                        ####
# ----------------------------------------------------------------------#
# Environmental covariates entering the model as ordinary additive terms
predictors = c(
  "bio6", "bio17", "cmi", "rsds",
  "igfc", "dtfw", "igsw", "roughness"
)

# Response `present` modelled by all covariates plus a species term wrapped in
# e(..., dim = 10); every right-hand-side term is joined with "+" in one go.
formula = as.formula(
  paste("present ~", paste(c(predictors, "e(species, dim = 10)"), collapse = " + "))
)

# 1. Cross validation
# One model is fitted per evaluation fold, using every record whose fold_eval
# differs from the held-out fold. A fit is accepted once its lowest recorded
# validation loss is below `loss_threshold`; otherwise the fold is refitted,
# at most `max_attempts` times. Accepted fits are saved for the evaluation
# step further down.
results_dir = "data/r_objects/msdm_embed_results"
# save() does not create missing directories, so make sure the target exists
# before any training time is spent.
dir.create(results_dir, recursive = TRUE, showWarnings = FALSE)

loss_threshold = 0.4  # acceptance threshold on the minimum validation loss
max_attempts = 10L    # upper bound on refits so a fold cannot retry forever

for(fold in 1:5){
  # Prepare data: everything outside the held-out fold is used for training
  train_data = dplyr::filter(model_data, fold_eval != fold)
  
  # Run model
  converged = FALSE
  attempt = 0L
  while(!converged && attempt < max_attempts){
    attempt = attempt + 1L
    
    msdm_embed_fit = dnn(
      formula,
      data = train_data,
      hidden = c(200L, 200L, 200L),
      loss = "binomial",
      epochs = 5000, 
      lr = 0.001,   
      batchsize = 4096,
      dropout = 0.25,
      burnin = 50,
      optimizer = config_optimizer("adam"),
      early_stopping = 200,
      validation = 0.2,
      device = "cuda"
    )
    
    # A run without any usable validation loss counts as not converged
    # (guarding explicitly avoids min() returning Inf with a warning).
    valid_loss = msdm_embed_fit$losses$valid_l
    converged = any(!is.na(valid_loss)) && 
      min(valid_loss, na.rm = TRUE) < loss_threshold
  }
  
  # Fail loudly instead of looping forever or storing a fit that never met
  # the acceptance threshold; fits of earlier folds are already on disk.
  if(!converged){
    stop(
      sprintf(
        "Fold %d: validation loss did not drop below %s within %d training attempts.",
        fold, format(loss_threshold), max_attempts
      ),
      call. = FALSE
    )
  }
  
  save(msdm_embed_fit, file = file.path(results_dir, paste0("msdm_embed_fit_fold", fold, ".RData")))
}

# Full model
# Fitted on the complete model_data table (no evaluation fold is held out).
# Compared with the cross-validation fits this run uses more epochs, a longer
# burn-in, more early-stopping patience and an explicit baseloss.
# NOTE(review): the exact meaning of burnin / baseloss / early_stopping is
# defined by cito::dnn - confirm against its documentation before changing.
msdm_embed_fit = dnn(
  formula = formula,
  data = model_data,
  # network architecture and loss
  hidden = c(200L, 200L, 200L),
  dropout = 0.25,
  loss = "binomial",
  # optimisation
  optimizer = config_optimizer("adam"),
  lr = 0.001,
  batchsize = 4096,
  epochs = 7500,
  # stopping rules and validation share
  baseloss = 1,
  burnin = 500,
  early_stopping = 300,
  validation = 0.2,
  # hardware
  device = "cuda"
)

# Store the fitted object under its original name for later use
save(
  msdm_embed_fit,
  file = "data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData"
)

# ----------------------------------------------------------------------#
# Evaluate model                                                     ####
# ----------------------------------------------------------------------#
# For every fold, load the model that was fitted without that fold and score
# it separately for each species occurring in the fold's held-out records.
# The result is a long table with one row per species x fold x metric, plus
# `obs` (number of records of that species outside the fold) and a model label.
msdm_embed_performance = lapply(1:5, function(fold){
  # Load into a dedicated environment and fetch the object without
  # inheritance: if the file does not contain `msdm_embed_fit` this errors
  # instead of silently falling back to a same-named global object.
  fit_env = new.env()
  load(paste0("data/r_objects/msdm_embed_results/msdm_embed_fit_fold", fold, ".RData"), envir = fit_env)
  msdm_embed_fit = get("msdm_embed_fit", envir = fit_env, inherits = FALSE)
  
  # Species of all records outside this fold; computed once per fold and
  # reused for the per-species `obs` counts below.
  train_species = model_data$species[model_data$fold_eval != fold]
  
  test_data_split = model_data %>% 
    dplyr::filter(fold_eval == fold) %>% 
    dplyr::group_split(species)
  
  lapply(test_data_split, function(test_data_spec){
    species = test_data_spec$species[1]
    
    performance = tryCatch({
      evaluate_model(msdm_embed_fit, test_data_spec)
    }, error = function(e){
      # Keep the run going with NA metrics, but report what failed so the
      # missing values can be traced back to their cause.
      warning(
        "evaluate_model() failed for species '", as.character(species), 
        "' in fold ", fold, ": ", conditionMessage(e),
        call. = FALSE
      )
      list(AUC = NA_real_, Accuracy = NA_real_, Kappa = NA_real_, 
           Precision = NA_real_, Recall = NA_real_, F1 = NA_real_, 
           TP = NA_real_, FP = NA_real_, TN = NA_real_, FN = NA_real_)
    })
    
    # Number of records of this species outside the current fold
    n_train = sum(train_species == species, na.rm = TRUE)
    
    performance %>% 
      as_tibble() %>% 
      mutate(
        species = !!species,
        obs = !!n_train,
        fold_eval = !!fold,
        model = "MSDM_embed"
      ) %>% 
      tidyr::pivot_longer(-any_of(c("species", "obs", "fold_eval", "model")), names_to = "metric", values_to = "value")
  }) %>% 
    bind_rows()
}) %>% 
  bind_rows()

save(msdm_embed_performance, file = "data/r_objects/msdm_embed_performance.RData")