# 03_04_modelling_msdm_rf.R
# Multi-species random forest SDM: cross-validated training, a full-data fit,
# and per-species evaluation of the cross-validated models.
library(dplyr)
library(tidyr)
library(caret)
library(ranger)

source("R/utils.R")

# Load the prepared modelling table; later code expects this to define
# `model_data`.
load("data/r_objects/model_data.RData")

# Add `present_fct`: the 0/1 `present` column recoded as a two-level factor
# labelled "A" (from "0") and "P" (from "1"). Letter labels are presumably
# used because caret's classProbs = TRUE requires class levels that are valid
# R names. The sf geometry is then dropped so the rest of the script works on
# a plain attribute table.
model_data <- sf::st_drop_geometry(
  dplyr::mutate(
    model_data,
    present_fct = factor(present, levels = c("0", "1"), labels = c("A", "P"))
  )
)

# ----------------------------------------------------------------------#
# Train model                                                        ####
# ----------------------------------------------------------------------#
# Predictors: columns bio1..bio19 plus the species identifier, so a single
# random forest is fitted jointly across all species.
predictors <- c(paste0("bio", 1:19), "species")

# Resampling / tuning set-up, identical for every fold, so built once.
# NOTE(review): no `method` is given, so caret uses its default resampling
# ("boot", here with 5 resamples). If 5-fold CV was intended, add
# `method = "cv"` -- confirm before changing.
train_ctrl <- caret::trainControl(
  search = "random",
  classProbs = TRUE,
  number = 5,
  summaryFunction = caret::twoClassSummary,
  savePredictions = "final"
)

# save() fails if the target directory does not exist.
dir.create("data/r_objects/msdm_rf_results", recursive = TRUE, showWarnings = FALSE)

# Cross validation: one model per held-out fold, each saved to its own file
# as an object named `rf_fit` (the evaluation section relies on that name).
for (fold in 1:5) {
  message("Fold ", fold)

  # Training rows: every background record plus all other records outside
  # the held-out fold. Non-background rows with NA fold_global give NA here
  # and are dropped by dplyr::filter().
  data_train <- model_data %>%
    dplyr::filter(record_type == "background" | fold_global != fold)

  # twoClassSummary() reports only ROC/Sens/Spec, so the selection metric
  # must be "ROC"; "Accuracy" is not in that result set (caret would warn
  # and fall back to ROC).
  rf_fit <- caret::train(
    x = data_train[, predictors],
    y = data_train$present_fct,
    method = "ranger",
    metric = "ROC",
    trControl = train_ctrl,
    tuneLength = 8,
    weights = data_train$weight,
    num.threads = 48 # passed through caret to ranger
  )

  save(rf_fit, file = paste0("data/r_objects/msdm_rf_results/msdm_rf_fit_fold", fold, ".RData"))
}

# Full model: same tuning set-up, fitted on every row of model_data
# (no fold held out). Relies on `predictors` defined in the training section.
# NOTE(review): no `method` is given to trainControl(), so caret uses its
# default resampling ("boot", 5 resamples here); add `method = "cv"` if
# 5-fold CV was intended.
train_ctrl <- caret::trainControl(
  search = "random",
  classProbs = TRUE,
  number = 5,
  summaryFunction = caret::twoClassSummary,
  savePredictions = "final"
)

# save() fails if the target directory does not exist.
dir.create("data/r_objects/msdm_rf_results", recursive = TRUE, showWarnings = FALSE)

# twoClassSummary() reports only ROC/Sens/Spec, so select on "ROC";
# "Accuracy" is not in that result set (caret would warn and use ROC anyway).
rf_fit <- caret::train(
  x = model_data[, predictors],
  y = model_data$present_fct,
  method = "ranger",
  metric = "ROC",
  trControl = train_ctrl,
  tuneLength = 8,
  weights = model_data$weight,
  num.threads = 48 # passed through caret to ranger
)

save(rf_fit, file = "data/r_objects/msdm_rf_results/msdm_rf_fit_full.RData")

# ----------------------------------------------------------------------#
# Evaluate model                                                     ####
# ----------------------------------------------------------------------#
# For each of the 5 CV folds: reload that fold's model, predict on the
# held-out rows (fold_global == fold) and compute per-species metrics.
# The result is one long tibble with columns species, obs, fold_global,
# model, metric and value.
msdm_rf_performance <- lapply(1:5, function(fold) {
  # Load into a private environment and fetch `rf_fit` explicitly. A bare
  # load() followed by `rf_fit` would silently pick up a global `rf_fit`
  # (e.g. the full model) if the file did not contain one.
  fit_env <- new.env()
  load(
    paste0("data/r_objects/msdm_rf_results/msdm_rf_fit_fold", fold, ".RData"),
    envir = fit_env
  )
  rf_fit <- get("rf_fit", envir = fit_env, inherits = FALSE)

  # Geometry was already dropped when model_data was prepared. A second
  # sf::st_drop_geometry() is unnecessary and errors on plain data frames
  # in older sf releases.
  test_data <- dplyr::filter(model_data, fold_global == fold)

  # present_fct is the same A/P factor the models were trained on.
  actual <- test_data$present_fct
  probs <- predict_new(rf_fit, test_data, type = "prob")
  preds <- predict_new(rf_fit, test_data, type = "class")

  # Per-species number of rows outside the held-out fold. It is computed
  # once per fold instead of re-filtering model_data for every species.
  # This matches filter(species == <sp>, fold_global != fold): NA species
  # or folds never count, and species missing here get 0.
  train_counts <- model_data %>%
    dplyr::filter(fold_global != !!fold, !is.na(species)) %>%
    dplyr::count(species, name = "obs")

  eval_dfs <- data.frame(
    species = test_data$species,
    actual,
    probs,
    preds
  ) %>%
    group_by(species) %>%
    group_split()

  lapply(eval_dfs, function(eval_df_spec) {
    spec_name <- eval_df_spec$species[1]

    # Best-effort per species: if any metric computation errors (such as
    # ROC for a species with a single class in this fold), all metrics
    # become NA and are removed by drop_na() below.
    performance <- tryCatch({
      auc <- pROC::roc(eval_df_spec$actual, eval_df_spec$probs, levels = c("P", "A"), direction = ">")$auc
      cm <- caret::confusionMatrix(eval_df_spec$preds, eval_df_spec$actual, positive = "P")

      list(
        auc = as.numeric(auc),
        accuracy = cm$overall["Accuracy"],
        kappa = cm$overall["Kappa"],
        precision = cm$byClass["Precision"],
        recall = cm$byClass["Recall"],
        f1 = cm$byClass["F1"],
        tp = cm$table["P", "P"],
        fp = cm$table["P", "A"],
        tn = cm$table["A", "A"],
        fn = cm$table["A", "P"]
      )
    }, error = function(e) {
      list(auc = NA_real_, accuracy = NA_real_, kappa = NA_real_,
           precision = NA_real_, recall = NA_real_, f1 = NA_real_,
           tp = NA_real_, fp = NA_real_, tn = NA_real_, fn = NA_real_)
    })

    n_obs <- train_counts$obs[match(spec_name, train_counts$species)]
    if (is.na(n_obs)) {
      n_obs <- 0L
    }

    performance %>%
      as_tibble() %>%
      mutate(
        species = spec_name,
        obs = n_obs,
        fold_global = fold,
        model = "msdm_rf"
      ) %>%
      tidyr::pivot_longer(
        -any_of(c("species", "obs", "fold_global", "model")),
        names_to = "metric",
        values_to = "value"
      ) %>%
      drop_na()
  }) %>%
    bind_rows()
}) %>%
  bind_rows()

save(msdm_rf_performance, file = "data/r_objects/msdm_rf_results_performance.RData")