# 04_06_rf_testing.R
# Random forest (caret + ranger) with species as a predictor, random absences:
# full-data fit, 5-fold cross-validation, and per-species evaluation.
library(dplyr)
library(tidyr)
library(caret)
library(ranger)

source("R/utils.R")

# ----------------------------------------------------------------------#
# Load and prepare data                                              ####
# ----------------------------------------------------------------------#
# load() restores the object `model_data` from the RData file.
load("data/r_objects/model_data_random_abs.RData")

# Keep only records that were assigned an evaluation fold, then encode
# species as a factor and the 0/1 `present` column as a two-level factor
# ("A" = 0, "P" = 1). "P" is also the name of the class-probability column
# read back during evaluation.
model_data = dplyr::filter(model_data, !is.na(fold_eval))
model_data = dplyr::mutate(
  model_data,
  species = as.factor(species),
  present_fct = factor(present, levels = c("0", "1"), labels = c("A", "P"))
)

# Model inputs: environmental covariates plus species identity.
predictors = c(
  "bio6", "bio17", "cmi", "rsds", "igfc",
  "dtfw", "igsw", "roughness", "species"
)

# Non-spatial copy of every fold-assigned record, used for the full-data fit.
full_data = sf::st_drop_geometry(model_data)

# ----------------------------------------------------------------------#
# Fit on all data                                                    ####
# ----------------------------------------------------------------------#
# A single random hyper-parameter candidate (search = "random",
# tuneLength = 1), assessed with caret's default resampling.
# twoClassSummary reports ROC, Sens and Spec, so the selection metric must be
# one of those. "Accuracy" is not produced, and caret would fall back to ROC
# with a warning.
train_ctrl = caret::trainControl(
  search = "random",
  classProbs = TRUE,
  summaryFunction = caret::twoClassSummary,
  savePredictions = "final"
)

# Run model
# Random search and resampling draw from R's RNG; fix it so reruns match.
set.seed(42)
rf_fit = caret::train(
  x = full_data[, predictors],
  y = full_data$present_fct,
  method = "ranger",
  metric = "ROC",
  trControl = train_ctrl,
  tuneLength = 1,
  # Remaining arguments are forwarded by caret to ranger::ranger().
  num.threads = 48,
  importance = "impurity",
  verbose = TRUE
)

save(rf_fit, file = "data/r_objects/msdm_rf/msdm_rf_fit_random_abs_full.RData")

# Variable importance of the full-data fit (importance = "impurity" above).
varimp = varImp(rf_fit)

# ----------------------------------------------------------------------#
# Cross validation                                                   ####
# ----------------------------------------------------------------------#
# For each outer fold, tune on all other folds (inner 5-fold CV over an
# mtry x min.node.size grid with gini splitting) and save the fit so the
# evaluation section can reload it.

# Define caret training routine. It is identical for every outer fold, so it
# is built once. twoClassSummary yields ROC/Sens/Spec, hence metric = "ROC" below.
train_ctrl = caret::trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,
  summaryFunction = caret::twoClassSummary,
  savePredictions = "final"
)

tune_grid = expand.grid(
  mtry = c(2, 4, 6, 8),
  splitrule = "gini",
  min.node.size = c(1, 4, 9, 16)
)

for (fold in 1:5) {
  ## Preparations #####
  data_train = dplyr::filter(model_data, fold_eval != fold) %>%
    sf::st_drop_geometry()

  # Run model
  # Per-fold seed: inner CV splits and forests are reproducible.
  set.seed(fold)
  rf_fit = caret::train(
    x = data_train[, predictors],
    y = data_train$present_fct,
    method = "ranger",
    metric = "ROC",
    trControl = train_ctrl,
    tuneGrid = tune_grid,
    num.threads = 48,
    verbose = FALSE
  )

  # The object name `rf_fit` is what the evaluation step expects after load().
  save(rf_fit, file = paste0("data/r_objects/msdm_rf/msdm_rf_fit_random_abs_fold", fold, ".RData"))
}

# ----------------------------------------------------------------------#
# Evaluate model                                                     ####
# ----------------------------------------------------------------------#
# For each outer fold, load that fold's saved fit, predict its held-out
# records, and compute metrics separately for every species. The result is a
# long tibble with columns species, obs, fold_eval, model, metric and value.
# `obs` is the species' number of training records in that fold.
msdm_rf_random_abs_performance = lapply(1:5, function(fold){
  # Restores `rf_fit` into this function's environment.
  load(paste0("data/r_objects/msdm_rf/msdm_rf_fit_random_abs_fold", fold, ".RData"))

  test_data = dplyr::filter(model_data, fold_eval == fold) %>%
    sf::st_drop_geometry()

  # Training-record counts per species for this fold, computed once rather
  # than re-filtering model_data for every species. table() on a factor
  # includes every level, so a species without training rows gets 0.
  train_obs = table(dplyr::filter(model_data, fold_eval != fold)$species)

  # Observed classes: "A" = absent (0), "P" = present (1).
  actual = test_data$present_fct
  # Predict from the model inputs only, mirroring the training call.
  test_x = test_data[, predictors]
  probs = predict(rf_fit, test_x, type = "prob")$P
  preds = predict(rf_fit, test_x, type = "raw")

  eval_dfs = data.frame(
    species = test_data$species,
    actual,
    probs,
    preds
  ) %>%
    group_by(species) %>%
    group_split()

  lapply(eval_dfs, function(eval_df_spec){
    species = eval_df_spec$species[1]
    n_train = as.integer(train_obs[as.character(species)])

    performance = tryCatch({
      # pROC: levels = c(controls, cases). With direction = ">", control ("P")
      # probabilities are expected to exceed case ("A") probabilities.
      auc = pROC::roc(eval_df_spec$actual, eval_df_spec$probs, levels = c("P", "A"), direction = ">")$auc
      cm = caret::confusionMatrix(eval_df_spec$preds, eval_df_spec$actual, positive = "P")

      # cm$table is Prediction (rows) x Reference (columns).
      list(
        AUC = as.numeric(auc),
        Accuracy = cm$overall["Accuracy"],
        Kappa = cm$overall["Kappa"],
        Precision = cm$byClass["Precision"],
        Recall = cm$byClass["Recall"],
        F1 = cm$byClass["F1"],
        TP = cm$table["P", "P"],
        FP = cm$table["P", "A"],
        TN = cm$table["A", "A"],
        FN = cm$table["A", "P"]
      )
    }, error = function(e){
      # Report the failure instead of dropping it silently.
      # NOTE(review): a single failure discards every metric for the species,
      # including confusion-matrix counts that may still be computable.
      # Presumably this happens e.g. when pROC::roc meets a one-class test set.
      warning("Evaluation failed for species '", as.character(species),
              "' in fold ", fold, ": ", conditionMessage(e), call. = FALSE)
      list(AUC = NA_real_, Accuracy = NA_real_, Kappa = NA_real_,
           Precision = NA_real_, Recall = NA_real_, F1 = NA_real_,
           TP = NA_real_, FP = NA_real_, TN = NA_real_, FN = NA_real_)
    })

    performance %>%
      as_tibble() %>%
      mutate(
        species = !!species,
        obs = !!n_train,
        fold_eval = !!fold,
        model = "MSDM_rf_random_abs"
      ) %>%
      tidyr::pivot_longer(-any_of(c("species", "obs", "fold_eval", "model")), names_to = "metric", values_to = "value")
  }) %>%
    bind_rows()
}) %>%
  bind_rows()

save(msdm_rf_random_abs_performance, file = "data/r_objects/msdm_rf_random_abs_performance.RData")