Skip to content
Snippets Groups Projects
Select Git revision
  • 298d7d6eeb8facb9ef9929ab007e0d75756c0e13
  • main default protected
2 results

04_06_rf_testing.R

Blame
  • König's avatar
    ye87zine authored
    298d7d6e
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    04_06_rf_testing.R 4.04 KiB
    library(dplyr)
    library(tidyr)
    library(caret)
    library(ranger)
    
    source("R/utils.R")
    
    # Load the pre-assembled modelling dataset (creates `model_data`)
    load("data/r_objects/model_data_random_abs.RData")
    
    # Keep only observations assigned to an evaluation fold and build the
    # factor columns caret needs: species as factor, and the response
    # recoded from "0"/"1" to "A" (absent) / "P" (present).
    model_data = model_data %>%
      dplyr::filter(!is.na(fold_eval)) %>%
      dplyr::mutate(
        species = as.factor(species),
        present_fct = factor(present, levels = c("0", "1"), labels = c("A", "P"))
      )
    
    # Environmental predictors plus species identity (multi-species model)
    predictors = c(
      "bio6", "bio17", "cmi", "rsds", "igfc",
      "dtfw", "igsw", "roughness", "species"
    )
    
    # Drop the sf geometry column before fitting
    full_data = sf::st_drop_geometry(model_data)
    
    # Training routine for the full-data model.
    # NOTE: twoClassSummary() reports ROC / Sens / Spec only (never Accuracy),
    # so the tuning metric below must be "ROC". Requesting "Accuracy" here
    # made caret emit a warning and silently fall back to ROC anyway.
    train_ctrl = caret::trainControl(
      search = "random",
      classProbs = TRUE, 
      summaryFunction = caret::twoClassSummary, 
      savePredictions = "final"
    )
    
    # Run model on the complete dataset (single random hyperparameter draw)
    rf_fit = caret::train(
      x = full_data[, predictors],
      y = full_data$present_fct,
      method = "ranger",
      metric = "ROC",        # matches twoClassSummary output (was "Accuracy")
      trControl = train_ctrl,
      tuneLength = 1,
      num.threads = 48,
      importance = 'impurity',
      verbose = TRUE         # TRUE instead of reassignable shorthand T
    )
    
    save(rf_fit, file = "data/r_objects/msdm_rf/msdm_rf_fit_random_abs_full.RData")
    
    # Impurity-based variable importance (importance = 'impurity' above)
    varimp = varImp(rf_fit)
    
    # Cross validation: refit with each evaluation fold held out --------------
    
    # Define caret training routine once — it is loop-invariant.
    # twoClassSummary() yields ROC/Sens/Spec, hence metric = "ROC" below
    # (the original "Accuracy" is never in the result set and caret would
    # warn and fall back to ROC).
    train_ctrl = caret::trainControl(
      method = "cv",
      number = 5,
      classProbs = TRUE, 
      summaryFunction = caret::twoClassSummary, 
      savePredictions = "final"
    )
    
    # Exhaustive grid over mtry and min.node.size (splitrule fixed to gini)
    tune_grid = expand.grid(
      mtry = c(2, 4, 6, 8),
      splitrule = "gini",
      min.node.size = c(1, 4, 9, 16)
    )
    
    for (fold in 1:5) {
      ## Preparations #####
      # Training set: all observations outside the current evaluation fold
      data_train = dplyr::filter(model_data, fold_eval != fold) %>% 
        sf::st_drop_geometry()
      
      # Run model
      rf_fit = caret::train(
        x = data_train[, predictors],
        y = data_train$present_fct,
        method = "ranger",
        metric = "ROC",      # was "Accuracy"; see note on twoClassSummary above
        trControl = train_ctrl,
        tuneGrid = tune_grid,
        num.threads = 48,
        verbose = FALSE      # FALSE instead of reassignable shorthand F
      )
      
      # One fitted model per evaluation fold
      save(rf_fit, file = paste0("data/r_objects/msdm_rf/msdm_rf_fit_random_abs_fold", fold, ".RData"))
    }
    
    # ----------------------------------------------------------------------#
    # Evaluate model                                                     ####
    # ----------------------------------------------------------------------#
    # For each evaluation fold: load that fold's fitted model, predict on the
    # held-out observations, and compute per-species performance metrics.
    # Result: one long-format tibble, one row per (species, fold, metric).
    msdm_rf_random_abs_performance = lapply(1:5, function(fold){
      # Loads `rf_fit` trained with this fold held out (saved by the CV loop)
      load(paste0("data/r_objects/msdm_rf/msdm_rf_fit_random_abs_fold", fold, ".RData"))
      
      # Held-out test data for this fold
      test_data = dplyr::filter(model_data, fold_eval == fold) %>% 
        sf::st_drop_geometry()
      
      # Recode the response to the same factor coding used during training
      actual = factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
      probs = predict(rf_fit, test_data, type = "prob")$P
      preds = predict(rf_fit, test_data, type = "raw")
      
      # Split predictions into one data frame per species for species-level
      # evaluation of the single multi-species model
      eval_dfs = data.frame(
        species = test_data$species,
        actual,
        probs,
        preds
      ) %>% 
        group_by(species) %>% 
        group_split()
      
      
      lapply(eval_dfs, function(eval_df_spec){
        species = eval_df_spec$species[1]
        
        # tryCatch: species with too few test observations or only one
        # observed class make roc()/confusionMatrix() fail; fall back to NA.
        performance = tryCatch({
          # pROC: levels = c(controls, cases); with direction = ">" this is
          # equivalent to treating "P" as the positive class on P-probabilities.
          auc = pROC::roc(eval_df_spec$actual, eval_df_spec$probs, levels = c("P", "A"), direction = ">")$auc
          cm = caret::confusionMatrix(eval_df_spec$preds, eval_df_spec$actual, positive = "P")
          
          # confusionMatrix table is indexed [prediction, reference]
          list(
            AUC = as.numeric(auc),
            Accuracy = cm$overall["Accuracy"],
            Kappa = cm$overall["Kappa"],
            Precision = cm$byClass["Precision"],
            Recall = cm$byClass["Recall"],
            F1 = cm$byClass["F1"],
            TP = cm$table["P", "P"],
            FP = cm$table["P", "A"],
            TN = cm$table["A", "A"],
            FN = cm$table["A", "P"]
          )
        }, error = function(e){
          list(AUC = NA_real_, Accuracy = NA_real_, Kappa = NA_real_, 
               Precision = NA_real_, Recall = NA_real_, F1 = NA_real_, 
               TP = NA_real_, FP = NA_real_, TN = NA_real_, FN = NA_real_)
        })
        
        # Pivot to long format; `!!` forces the local species/fold values in
        # the data-masking context. NOTE(review): `obs` counts the species'
        # *training* observations (fold_eval != fold), not test observations
        # — presumably intentional (training sample size); confirm downstream.
        performance_summary = performance %>% 
          as_tibble() %>% 
          mutate(
            species = !!species,
            obs = nrow(dplyr::filter(model_data, species == !!species, fold_eval != !!fold)),
            fold_eval = !!fold,
            model = "MSDM_rf_random_abs",
          ) %>% 
          tidyr::pivot_longer(-any_of(c("species", "obs", "fold_eval", "model")), names_to = "metric", values_to = "value")
      }) %>% 
        bind_rows()
    }) %>% 
      bind_rows()
    
    save(msdm_rf_random_abs_performance, file = paste0("data/r_objects/msdm_rf_random_abs_performance.RData"))